# flash_causal_lm.py
import math
import itertools

import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)

BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
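# Example: with BLOCK_SIZE = 16, a request with a 20-token prompt and
# max_new_tokens = 13 reserves total_tokens = 20 + 13 - 1 = 32 KV slots in
# FlashCausalLMBatch.from_pb, i.e. math.ceil(32 / 16) = 2 cache blocks.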


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size
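        # Example: for torch.float16 the element size is 2 bytes, so x = 16 // 2 = 8
        # and the key cache below has shape (num_blocks, num_heads, head_size // 8, 16, 8),
        # packing 8 half-precision values (16 bytes) in the innermost dimension;
        # the value cache keeps the simpler (num_blocks, num_heads, head_size, 16) layout.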

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)
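        # self.slots[b, s] == b * BLOCK_SIZE + s, so block b owns the BLOCK_SIZE
        # consecutive slots [b * 16, (b + 1) * 16), e.g. block 3 owns slots 48..63.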

    def allocate(self, batch: "FlashCausalLMBatch"):
        # Get free block indices by finding values in the mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= batch.blocks
        ), f"Out of available cache blocks: asked for {batch.blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[: batch.blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
        batch.slots = torch.concat(slots).to(batch.input_ids.device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]

    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)
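            # start_slots records where this request's slots begin in the batch-wide
            # slots tensor; request_slot_indices (below) enumerates the slots its
            # prompt tokens write to during prefill:
            # cumulative_max_length .. cumulative_max_length + input_length - 1.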

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
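            # all_prefill_logprobs / no_prefill_logprobs track whether every (or no)
            # request asked for prompt logprobs, enabling the fast paths after the loop.
            # If this request asked for prefill logprobs we keep one output row per
            # prompt token (input_length rows); otherwise only the last row is needed
            # to sample the first new token. prefill_cu_outlens tracks the cumulative
            # number of kept rows per request.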

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        global CACHE_MANAGER
        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        CACHE_MANAGER.free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )
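        # max_length is the longest sequence any request in the merged batch can
        # still reach (tokens so far + tokens left to generate); it sizes the
        # padded all_input_ids_tensor below.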

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            global CACHE_MANAGER
            # Free blocks
            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch):
        global CACHE_MANAGER

        torch.cuda.empty_cache()
        try:
            CACHE_MANAGER = CacheManager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
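        # Example: with num_kv_heads = 16, head_size = 128, num_layers = 32 and
        # float16 (2 bytes), one block holds 16 * 16 * 128 = 32768 elements per
        # layer for the keys (and as many for the values), so
        # total_cache_size = 32 * 32768 * 2 * 2 = 4 MiB of KV cache per block.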

        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )

        num_blocks = (
            int(free_memory // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + CACHE_MANAGER.num_blocks
        )

        del CACHE_MANAGER
        del batch
        torch.cuda.empty_cache()

        CACHE_MANAGER = CacheManager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        return int(num_blocks * BLOCK_SIZE)

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        global CACHE_MANAGER

        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=CACHE_MANAGER.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None
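        # prefill is only True for the first forward pass of a batch: from_pb sets
        # cu_seqlen_prefill, and it is reset to None below once the prompt has been
        # processed, after which the batch runs in decode mode (one token per request).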

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            CACHE_MANAGER.allocate(batch)

        try:
            out = self.forward(
                batch.input_ids,
                batch.position_ids,
                batch.cu_seqlen_prefill,
                batch.block_tables_tensor,
                batch.slots[batch.slot_indices],
                batch.input_lengths_tensor,
                batch.max_seqlen,
                batch.prefill_head_indices,
            )
        except Exception as e:
            del batch
            raise e

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
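            # batch.cu_seqlen_prefill[1:] - 1 are the flattened positions of each
            # request's last prompt token, so after the line below slot_indices
            # holds exactly one slot index per request (decode layout).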
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops: the first one can run entirely asynchronously from the GPU,
        # while the second one needs a GPU <-> CPU sync first.
        # It is faster to delay this sync for as long as possible.

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1
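        # In the block below, the logprob of prompt token t is read from the logits
        # at position t - 1; prefill_tokens_indices was therefore filled above with
        # the input ids shifted by one.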

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + 1
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        return generations, batch