# flash_causal_lm.py
import math
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
    TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)

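# Number of token slots stored in one paged-attention cache block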
BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size
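        # `x` packs several head elements into the innermost dimension of the key
        # cache below, presumably so the attention kernel can read them as one
        # contiguous chunk (vLLM-style paged-attention layout).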

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)
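        # slots[block, j] is the global slot index of the j-th token position in
        # `block`; flattening a row of block indices yields a sequence's slots.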

    def allocate(self, batch: "FlashCausalLMBatch"):
        # Get free block indices by finding values in the mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= batch.blocks
        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[: batch.blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
        batch.slots = torch.concat(slots).to(batch.input_ids.device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
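            # e.g. with hypothetical numbers: input_length=10, max_new_tokens=20
            # -> total_tokens=29, needed_blocks=ceil(29 / 16)=2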
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
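                # Keep the logits of every prompt position for this request so its
                # per-token prefill logprobs can be computed later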
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
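                # Only the last prompt position is needed to sample the first token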
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Mark the slots owned by this request (already-filled slots plus those
            # reserved for its remaining tokens) so they are kept after filtering
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        global CACHE_MANAGER
        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        CACHE_MANAGER.free(block_indices_to_free)
        # Needed to avoid freeing these blocks in __del__ when this batch goes out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

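        # First pass over the batches to compute the sizes of the merged tensors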
        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        # Needed to avoid freeing the blocks in __del__ when the source batches go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            global CACHE_MANAGER
            # Free blocks
            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch):
        global CACHE_MANAGER

        torch.cuda.empty_cache()
        try:
            CACHE_MANAGER = CacheManager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
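        # e.g. with hypothetical sizes: 40 layers, 16 KV heads, head_size 128, fp16
        # -> 16 * 16 * 128 * 40 * 2 * 2 bytes = 5 MiB of KV cache per block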

        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )
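        # i.e. keep (1 - MEMORY_FRACTION) of the total GPU memory out of the cache budget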

        num_blocks = (
            int(free_memory // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + CACHE_MANAGER.num_blocks
        )

        del CACHE_MANAGER
        del batch
        torch.cuda.empty_cache()

        CACHE_MANAGER = CacheManager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        return int(num_blocks * BLOCK_SIZE)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        global CACHE_MANAGER

        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=CACHE_MANAGER.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            CACHE_MANAGER.allocate(batch)

        try:
            out = self.forward(
                batch.input_ids,
                batch.position_ids,
                batch.cu_seqlen_prefill,
                batch.block_tables_tensor,
                batch.slots[batch.slot_indices],
                batch.input_lengths_tensor,
                batch.max_seqlen,
                batch.prefill_head_indices,
            )
        except Exception as e:
            del batch
            raise e

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When the batch size is 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
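            # After prefill, keep a single slot index per sequence (its last prompt
            # token); from here on the batch advances one token per sequence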
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            next_token_ids,
            next_token_logprobs,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            next_token_id,
            next_token_logprob,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Drop the generated token to keep only the prefill positions, and prepend NaN for the first prompt token (it has no logprob)
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids for token_id in top_token_ids
                    ]
                    top_tokens = TopTokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + 1
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        return generations, batch