# flash_causal_lm.py

import math
import itertools
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)

BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size
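        # The key cache groups `x` elements of the head dimension together
        # (x * element_size == 16 bytes) so the innermost cache dimension is a small
        # contiguous vector; the value cache keeps a [num_heads, head_size, block_size]
        # layout per block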

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(self, batch: "FlashCausalLMBatch"):
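        """Reserve cache blocks and slots for every sequence in `batch`.

        Fills `batch.block_tables`, `batch.block_tables_tensor` and `batch.slots`,
        then marks the chosen blocks as used in `free_block_mask`.
        """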
        # Get free block indices by finding values in the mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= batch.blocks
        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[: batch.blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
        batch.slots = torch.concat(slots).to(batch.input_ids.device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
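        """Return a batch containing only the requests in `request_ids`.

        Tensors are re-indexed with the kept positions, and the cache blocks of
        the dropped requests are freed in the global `CACHE_MANAGER`.
        """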
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        global CACHE_MANAGER
        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        CACHE_MANAGER.free(block_indices_to_free)
        # Needed to avoid freeing the kept blocks when this batch goes out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
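        """Merge several decode batches into a single batch.

        Output tensors are pre-allocated at their final size, then each batch is
        copied into its slice with slot and batch indices offset accordingly.
        """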
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        # Needed to avoid freeing the blocks when the source batches go out of scope
        for b in batches:
            b.block_tables = None
            del b
        torch.cuda.empty_cache()

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    def __del__(self):
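        # Free this batch's cache blocks unless ownership was already handed over
        # (filter/concatenate set `block_tables` to None for exactly that reason)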
        if self.block_tables is not None and self.block_tables:
            global CACHE_MANAGER
            # Free blocks
            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
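        """Allocate the global `CacheManager` and run one prefill step.

        The cache is sized from `max_total_tokens`; any failure is logged with a
        hint to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`
        and re-raised.
        """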
        global CACHE_MANAGER

        torch.cuda.empty_cache()
        try:
            CACHE_MANAGER = CacheManager(
                # Adds some wiggle room
                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
            logger.exception(
                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                f"prefill tokens. "
                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
            )
            raise e
        del batch
        torch.cuda.empty_cache()

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
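        """Run the model forward pass, injecting the global paged KV cache."""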
        global CACHE_MANAGER

        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=CACHE_MANAGER.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
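        """Run one model step (prefill or decode) over the whole batch.

        Allocates cache blocks if needed, runs the forward pass, samples one next
        token per sequence and returns the generations plus the updated batch
        (or None when every request has stopped).
        """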
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            CACHE_MANAGER.allocate(batch)

        try:
            out = self.forward(
                batch.input_ids,
                batch.position_ids,
                batch.cu_seqlen_prefill,
                batch.block_tables_tensor,
                batch.slots[batch.slot_indices],
                batch.input_lengths_tensor,
                batch.max_seqlen,
                batch.prefill_head_indices,
            )
        except Exception as e:
            del batch
            torch.cuda.empty_cache()
            raise e

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
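            # After prefill, keep a single slot index per sequence: the slot of the
            # last prompt token (it is incremented below before the next decode step)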
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove the generated token to keep only prefill logprobs and add nan for the first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + 1
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        return generations, batch