# flash_causal_lm.py
import math
import itertools

import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)

BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
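# Lifecycle in this file: FlashCausalLM.warmup builds CACHE_MANAGER, generate_token
# allocates blocks for prefill batches through it, and blocks are freed again when
# requests are filtered out of a batch or the batch is garbage collected.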


class CacheManager:
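    """Paged attention KV cache split into fixed-size blocks of BLOCK_SIZE slots.

    For every layer, `kv_cache` holds a (key, value) tensor pair: the key cache is
    shaped [num_blocks, num_heads, head_size // x, block_size, x] (with
    x = block_size // element_size, a layout chosen for the paged attention
    kernels) and the value cache [num_blocks, num_heads, head_size, block_size].
    `free_block_mask` tracks which blocks are unused and `slots` maps every block
    to its flat slot indices.
    """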
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(self, batch: "FlashCausalLMBatch"):
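        """Reserve `batch.blocks` free blocks and fill in the batch's block tables and slots."""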
        # Get free block indices by finding values in the mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= batch.blocks
        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[: batch.blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
        batch.slots = torch.concat(slots).to(batch.input_ids.device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
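        """Mark the given blocks as free again so they can be reused by later batches."""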
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
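            # e.g. with BLOCK_SIZE = 16, input_length = 10 and max_new_tokens = 7,
            # total_tokens = 16, so one block (16 slots) covers the whole request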
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        global CACHE_MANAGER
        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        CACHE_MANAGER.free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )
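        # max_length bounds how long any sequence can still grow (prompt plus the
        # new tokens it may still generate); all_input_ids_tensor is padded to it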
        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            global CACHE_MANAGER
            # Free blocks
            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch
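    # Illustrative call sequence (simplified): warmup(batch) sizes the KV cache,
    # FlashCausalLMBatch.from_pb builds a prefill batch, generate_token(batch) is
    # then called repeatedly, with batch.filter(...) / FlashCausalLMBatch.concatenate(...)
    # applied between steps as requests finish or new batches arrive.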

    def warmup(self, batch: FlashCausalLMBatch):
        global CACHE_MANAGER

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(self.device)
        try:
            CACHE_MANAGER = CacheManager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize(self.device)
        peak_memory = torch.cuda.max_memory_reserved(self.device)

        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
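        # e.g. for a hypothetical 32-layer model with 32 KV heads of size 128 in
        # float16: cache_block_size = 16 * 32 * 128 = 65536 elements and
        # total_cache_size = 32 * 65536 * 2 * 2 bytes = 8 MiB per block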

        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

        # 0.98 to add some wiggle room
        num_blocks = (
            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + batch.blocks
        )

        del CACHE_MANAGER
        del batch
        torch.cuda.empty_cache()

        CACHE_MANAGER = CacheManager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        return int(num_blocks * BLOCK_SIZE)

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        global CACHE_MANAGER

        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=CACHE_MANAGER.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            CACHE_MANAGER.allocate(batch)

        try:
            out = self.forward(
                batch.input_ids,
                batch.position_ids,
                batch.cu_seqlen_prefill,
                batch.block_tables_tensor,
                batch.slots[batch.slot_indices],
                batch.input_lengths_tensor,
                batch.max_seqlen,
                batch.prefill_head_indices,
            )
        except Exception as e:
            del batch
            raise e

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1
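        # The batch is now in decode mode: one input token per sequence, with
        # position ids, input lengths and slot indices all advanced by one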

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + 1
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        return generations, batch