import math
import time
import itertools
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
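    # Candidate tokens from speculative decoding for each request; None when there are none (e.g. at prefill)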
    speculative_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
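            # Each block holds BLOCK_SIZE tokens, so this is the number of KV-cache slots reserved for the batch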
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        speculative_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)
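            # This request owns `total_tokens` contiguous slots starting at `cumulative_max_length`;
            # during prefill only the first `input_length` of them are indexed below.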

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
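            # When prefill logprobs are requested, the LM head is run on every prompt position;
            # otherwise only the last position is needed to sample the next token.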

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
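            # Keep the slots this request still needs: the ones already filled by its
            # input tokens plus those reserved for the tokens it has left to generate.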
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when this batch goes out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
438
439
440
441
442
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
443
            cu_seqlen_prefill=None,
444
445
446
447
448
449
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
450
            max_seqlen=max_seqlen,
451
452
453
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
454
            input_lengths=input_lengths,
455
            input_lengths_tensor=input_lengths_tensor,
456
457
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
458
459
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
460
            next_token_chooser=next_token_chooser,
461
            stopping_criterias=stopping_criterias,
Nicolas Patry's avatar
Nicolas Patry committed
462
463
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
464
465
            blocks=blocks,
            max_blocks=max_blocks,
Nicolas Patry's avatar
Nicolas Patry committed
466
            speculative_ids=speculative_ids,
467
468
469
470
471
472
473
474
475
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
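            # max_length sets the width of the padded all_input_ids_tensor: the longest any
            # sequence can grow to (its current tokens plus what it can still generate,
            # including speculative tokens).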
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch):
        torch.cuda.empty_cache()
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
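        # A rough, hypothetical example: with BLOCK_SIZE=16, num_kv_heads=8, head_size=128,
        # num_layers=32 and a 2-byte dtype, one block costs
        # 32 * (16 * 8 * 128) * 2 * 2 = 2,097,152 bytes (~2 MiB) of KV cache.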
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )

        num_blocks = (
            int(free_memory // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        return int(num_blocks * BLOCK_SIZE)

    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        if batch.speculative_ids is not None:
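            # Speculative decoding: expand each request to `speculative_length + 1` rows
            # (the current token plus each speculative candidate) so they are all scored
            # in a single forward pass.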
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
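            # `needed_blocks_slots` was computed in `from_pb`; the KV-cache blocks and slots
            # are only allocated here, on the batch's first generate_token call (prefill).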
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out = self.forward(batch)
        except Exception as e:
            del batch
            raise e

        if isinstance(out, tuple):
            out, speculative_logits = out
        else:
            speculative_logits = None

        if prefill:
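            # When prefill logprobs are requested the model output keeps a row per prompt
            # position, so gather the rows pointed to by prefill_next_token_indices to get
            # each sequence's next-token logits.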
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length
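        # Advance the whole batch by the number of accepted tokens per request so the next
        # decode step continues from the right position, slot and length for each sequence.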

        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            left = 0

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )
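                # `left` counts accepted tokens that fall after a stop condition; they are
                # trimmed from the tokens returned for this request.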

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for (top_token_ids, top_token_logprobs) in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)