import math
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
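            # Upper bound on tokens this batch can hold: every reserved block contributes BLOCK_SIZE slots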
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
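        # Gather the raw inputs and the largest truncation value so the whole batch
        # can be tokenized in a single call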
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        speculative_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
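            # Reserve enough BLOCK_SIZE-sized cache blocks to cover every slot this request can use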
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

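            # During prefill, each prompt token of this request gets its own slot index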
            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

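        # When every request (or none) asked for prefill logprobs, the per-request
        # indices collapse to simple slices of cu_seqlen_prefill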
        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Mark the slots this request has already used and still needs so they survive filtering
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
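            # Longest possible finished sequence: prompt + remaining new tokens + speculative tokens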
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch):
        torch.cuda.empty_cache()
        try:
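            # Allocate a cache covering only this warmup batch and run one prefill step
            # to exercise peak memory usage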
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
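        # Each cache block stores BLOCK_SIZE tokens per layer, for both keys and values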
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )

        num_blocks = (
            int(free_memory // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        return int(num_blocks * BLOCK_SIZE)

    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
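            # Interleave each sequence's next input token with its speculative tokens and
            # flatten everything into one ragged batch of B * new_length tokens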
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block table of each sequence once per speculative position
            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1)
            block_tables = block_tables.reshape(B * new_length, -1).contiguous()
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out = self.forward(batch)
        except Exception as e:
            del batch
            raise e

        if isinstance(out, tuple):
            out, speculative_logits = out
        else:
            speculative_logits = None


        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            get_speculate(),
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )

        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (
            input_length,
            all_input_ids,
            n_accepted_ids
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

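            # Write the accepted token ids (one or more per sequence when speculating)
            # into the padded all_input_ids_tensor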
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length


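        # Keep only the last accepted token of each sequence as the next input and
        # advance per-sequence lengths, positions and slots by the number of accepted tokens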
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            left = 0
            before = stopping_criteria.current_tokens

            current_stopped = False
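            # Decode every accepted token for this request and check the stopping criteria,
            # breaking early if a stop condition triggers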
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove the generated token to keep only prefill tokens, and add
                    # nan for the first prompt token, which has no logprob
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids for token_id in top_token_ids
                    ]
                    top_tokens = Tokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids.item()
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        return generations, batch