import math
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
    TopTokens,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
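            # max_tokens reported to the router is the number of KV cache slots reserved for this batch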
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
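        # Collect the raw inputs and the largest per-request truncation value,
        # then tokenize the whole batch in a single call.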
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

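        # Track whether all or none of the requests asked for prefill logprobs so
        # the lm_head indices can be taken directly from cu_seqlen_prefill in those cases.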
        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

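            # Each prefill token of this request is written to consecutive slots in
            # the flat KV cache slots tensor, starting at this request's start slot.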
            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

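        # lm_head indexing: if every request wants prefill logprobs the head runs on
        # all prefill positions, if none do it only runs on each request's last
        # position, otherwise use the per-request indices built above.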
        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

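        # Re-pack the kept requests: slots are compacted below, so start_slots and
        # slot_indices are rebuilt from a fresh cumulative offset.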
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

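        # Allocate the concatenated tensors once, then copy each batch into its slice.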
        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch):
        torch.cuda.empty_cache()
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

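        # Keep (1 - MEMORY_FRACTION) of the total GPU memory as headroom; the rest of
        # the currently free memory is given to the KV cache.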
        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )

        num_blocks = (
            int(free_memory // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        return int(num_blocks * BLOCK_SIZE)

    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
            input_ids=batch.input_ids,
            position_ids=batch.position_ids,
            cu_seqlen_prefill=batch.cu_seqlen_prefill,
            kv_cache=get_cache_manager().kv_cache,
            block_tables=batch.block_tables_tensor,
            slots=batch.slots[batch.slot_indices],
            input_lengths=batch.input_lengths_tensor,
            max_s=batch.max_seqlen,
            lm_head_indices=batch.prefill_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
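        # A batch is in the prefill phase as long as cu_seqlen_prefill is set; it is
        # cleared below after the first generate_token call.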
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out = self.forward(batch)
        except Exception as e:
            del batch
            raise e

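        # When prefill logprobs are requested, out contains logits for every prompt
        # position; select only the positions used to sample each request's first new token.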
        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
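            # Keep only the slot index of each sequence's last token: decode writes a
            # single new KV entry per sequence per step.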
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            next_token_ids,
            next_token_logprobs,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            next_token_id,
            next_token_logprob,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids for token_id in top_token_ids
                    ]
                    top_tokens = TopTokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + 1
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        return generations, batch