flash_causal_lm.py 45.3 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
14
from typing import Optional, Tuple, List, Type, Dict
15

OlivierDehaene's avatar
OlivierDehaene committed
16
from text_generation_server.models import Model
17
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
18
from text_generation_server.utils.speculate import get_speculate
19
20
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
21
    Tokens,
22
23
24
    Generation,
    GeneratedText,
)
25
26
27
28
29
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
30
from text_generation_server.pb import generate_pb2
31
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
32
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
33
from text_generation_server.utils.dist import MEMORY_FRACTION
34
35
36

# OpenTelemetry tracer for this module; its `start_as_current_span` is used
# below to wrap batch operations such as `filter` and `concatenate` in spans.
tracer = trace.get_tracer(__name__)

37

38
39
40
41
@dataclass
class FlashCausalLMBatch(Batch):
    """State of a batch of requests served with flash + paged attention.

    Combines per-request Python bookkeeping (token lists, detokenization
    offsets, stopping criteria) with the tensors consumed by the model. A
    batch is built for prefill by ``from_pb``, shrunk by ``filter`` and merged
    by ``concatenate``. The KV-cache blocks listed in ``block_tables`` are
    returned to the cache manager in ``__del__`` unless ownership was handed
    over to another batch (``block_tables`` set to ``None``).
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # None when no speculative ids are available (``from_pb`` always sets None)
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b + 1 containing the cumulative sequence lengths of the
    # sequences in the batch (first entry is 0), only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize the batch as a ``generate_pb2.CachedBatch`` message.

        ``max_tokens`` is the number of KV-cache token slots covered by this
        batch's blocks (``blocks * BLOCK_SIZE``).
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize the ``inputs`` of every request in one tokenizer call.

        All inputs are truncated to the largest ``truncate`` value among the
        requests; ``from_pb`` then applies each request's own ``truncate``.

        Returns:
            One list of token ids per request, in request order.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from a protobuf ``Batch``.

        Tokenizes every request, computes how many KV-cache blocks/slots each
        request needs (prompt + ``max_new_tokens`` - 1 + speculative tokens)
        and creates the prefill tensors on ``device``. ``block_tables``,
        ``block_tables_tensor`` and ``slots`` are left as ``None``; they are
        set later by the cache manager.

        Args:
            pb: Incoming batch of requests.
            tokenizer: Tokenizer for the prompts and the stopping criteria.
            dtype: Dtype forwarded to the next-token chooser.
            device: Device on which the batch tensors are created.

        Returns:
            A new batch ready for the prefill forward pass.
        """
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        # get_speculate() takes no per-request input: read it once.
        speculative_length = get_speculate()

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            # Prefix offset sits 5 tokens before the read offset, clamped at 0.
            # NOTE(review): offsets are presumably used as slice starts into
            # the growing token list; a negative value (prompts < 5 tokens)
            # would index from the end of that list instead of its start.
            prefix_offsets.append(max(input_length - 5, 0))
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Request wants prefill logprobs: keep every prompt position
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Otherwise keep only the last prompt position
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor (zero-padded up to max_length per row)
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, request_input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(request_input_ids)] = request_input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # torch.cat already returns a tensor: move/cast it rather than
            # copy-constructing it with torch.tensor (which emits a warning).
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Keep only the requests whose ids are in ``request_ids``.

        KV-cache blocks of the dropped requests are freed immediately. The
        blocks of the kept requests are handed to the returned batch, and
        ``self.block_tables`` is set to ``None`` so that ``self.__del__``
        does not free them.

        Args:
            request_ids: Ids of the requests to keep; the new batch follows
                this order.

        Returns:
            ``self`` when ``len(request_ids) == len(self)`` (assumed to be the
            same requests), otherwise a new batch.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU): slot of this request's last input token
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice: keep the slots this request still needs
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into a single batch without prefill metadata.

        Tensors are copied into freshly allocated buffers sized for the union
        of the batches; ``all_input_ids_tensor`` and ``block_tables_tensor``
        are zero-padded to the widest input. Block tables are moved to the
        result: each input batch's ``block_tables`` is set to ``None`` so its
        ``__del__`` does not free blocks the merged batch now owns.

        Args:
            batches: Batches to merge, in order.

        Returns:
            The merged batch, reusing ``batches[0].batch_id``.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            # Longest row needed: current input plus every token a request
            # may still generate, plus the speculative ones
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's request indices by the number of requests
            # preceding it. A fresh dict is filled so that the input batches'
            # own mappings are never mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        # Return this batch's KV-cache blocks to the cache manager, unless
        # they were handed over (filter/concatenate set block_tables to None).
        if self.block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Store the KV-cache geometry and initialise the base ``Model``.

        Args:
            model: the underlying flash-attention model module.
            tokenizer: tokenizer forwarded to the base class.
            num_layers: number of layers; used to size the KV cache.
            num_kv_heads: number of key/value heads; used to size the KV cache.
            head_size: per-head dimension; used to size the KV cache.
            dtype: dtype forwarded to the base class (also used for the cache).
            device: device forwarded to the base class (also used for the cache).
            rank: shard rank forwarded to the base class.
            world_size: number of shards forwarded to the base class.
            sliding_window: optional sliding-window size forwarded to the base
                class.
        """
        # Geometry read back when the cache manager is (re)built
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        # Captured decode CUDA graphs, keyed by batch size
        self.cuda_graphs = {}

        # This model needs no padding: requests are packed back to back
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class consumed by this model (always ``FlashCausalLMBatch``)."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for batch size ``bs``.

        Static input buffers are allocated for ``bs`` requests:
        - zero input/position ids;
        - ``slots = arange(bs)``;
        - every input length set to ``max_s``;
        - each block-table row equal to ``arange(max_bt)``.

        The model is run once eagerly as a warmup. One forward pass is then
        captured into a ``torch.cuda.CUDAGraph`` using the shared ``MEM_POOL``.

        The entry is stored in ``self.cuda_graphs[bs]`` only after capture
        has succeeded. ``forward`` replays whatever entry it finds for a
        padded batch size, and ``warmup`` keeps going after a capture
        failure. Registering the entry up front (as before) could therefore
        leave a half-built entry behind: a graph that was never captured and
        no ``"logits"`` key.

        Args:
            bs: batch size to capture the graph for.
            max_s: maximum sequence length passed to the model.
            max_bt: number of columns in the static block-table buffer.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        graph = torch.cuda.CUDAGraph()
        # Built locally; published to self.cuda_graphs only on success
        entry = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
            "graph": graph,
        }

        torch.cuda.synchronize()
        # Run once outside the graph to warm up
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=None,
            )
        # Output tensors are static: every replay overwrites them
        entry["logits"] = logits
        entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

        # Capture succeeded: make the graph visible to forward()
        self.cuda_graphs[bs] = entry

748
    def warmup(self, batch: FlashCausalLMBatch):
        """Size the KV cache from one warmup step, then capture CUDA graphs.

        ``batch`` should be the biggest batch the server could ever receive.

        Steps:
        1. Create a cache with exactly ``batch.blocks`` blocks and run one
           :meth:`generate_token` step on it.
        2. Convert the free device memory into a block count. The memory is
           reduced by the ``(1 - MEMORY_FRACTION)`` share of the device and a
           5% margin.
        3. Rebuild the cache manager with that many blocks.
        4. If ``CUDA_GRAPHS`` is set, capture decode graphs for its sizes. A
           failure here is logged and does not abort warmup.

        Returns:
            ``int(num_blocks * BLOCK_SIZE)`` for the rebuilt cache (presumably
            the number of token slots it holds).

        Raises:
            RuntimeError: if the warmup step runs out of CUDA memory.
        """
        torch.cuda.empty_cache()

        # Everything set_cache_manager needs apart from the block count
        cache_layout = (
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        try:
            warmup_cache = set_cache_manager(batch.blocks, *cache_layout)
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Block-count estimate inspired by vllm (https://github.com/vllm-project/vllm)
        element_bytes = torch.tensor([], dtype=self.dtype).element_size()
        elems_per_block = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # x2: presumably one buffer for keys and one for values
        bytes_per_block = self.num_layers * elems_per_block * 2 * element_bytes

        free_bytes, _ = torch.cuda.mem_get_info(self.device)
        device_bytes = torch.cuda.get_device_properties(self.device).total_memory
        # Keep (1 - MEMORY_FRACTION) of the whole device out of reach
        usable_bytes = max(0, free_bytes - (1 - MEMORY_FRACTION) * device_bytes)

        # 5% wiggle room. The warmup cache's blocks are added back because
        # they are still allocated here, so they already count as "used".
        spare_blocks = int((usable_bytes * 0.95) // bytes_per_block)
        num_blocks = spare_blocks + warmup_cache.num_blocks

        # Drop the local references before replacing the global cache manager
        del batch
        del warmup_cache

        set_cache_manager(num_blocks, *cache_layout)

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                for bs in CUDA_GRAPHS:
                    # With speculation on, only sizes >= speculate + 1 are captured
                    if self.speculate is not None and bs < self.speculate + 1:
                        continue
                    self.cuda_graph_warmup(bs, max_s, max_bt)
            except Exception:
                logger.exception("Decode cuda graph warmup failed")

        return int(num_blocks * BLOCK_SIZE)
816

817
818
819
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model forward pass for ``batch``.

        If ``batch.speculative_ids`` is set, each request is expanded into
        ``1 + speculative_length`` consecutive rows: its current token
        followed by its speculative tokens. Position ids, slots and input
        lengths are offset by ``0..speculative_length`` within each group,
        and the request's block table is repeated for every row.

        The model runs eagerly for prefill, for speculative batches, or when
        no CUDA graph is registered for the padded batch size. Otherwise the
        inputs are copied into the graph's static buffers, the graph is
        replayed, and its outputs are sliced back to the real batch size.

        Returns:
            ``(logits, speculative_logits)``; ``speculative_logits`` may be
            ``None``.
        """
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            n_requests, speculative_length = speculative_ids.shape
            group = speculative_length + 1
            # Offsets 0..speculative_length, broadcast over the requests
            step = torch.arange(group, device=position_ids.device).unsqueeze(0)
            step_int = step.to(dtype=torch.int32)

            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(n_requests, group) + step
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(n_requests, group) + step_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(n_requests, group) + step_int
            ).view(-1)
            # Every expanded row reuses its request's block table
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(n_requests, group, -1)
                .reshape(n_requests * group, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        bs = input_ids.shape[0]
        # Pad the batch size (3 -> 4, 4..8 -> 8, >8 -> next multiple of 8)
        # before looking up a captured graph
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8
        else:
            padded_bs = bs
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        run_eagerly = (
            cu_seqlen_prefill is not None
            or cuda_graph is None
            or speculative_ids is not None
        )
        if run_eagerly:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # The static buffers may be larger than this batch: fill the leading
        # rows; padding rows get slot -1 and input length 0
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        cuda_graph["graph"].replay()

        # Trim the (possibly padded) static outputs back to the real batch
        logits = cuda_graph["logits"][:bs]
        speculative_logits = cuda_graph["speculative_logits"]
        if speculative_logits is not None:
            speculative_logits = speculative_logits[:bs]
        return logits, speculative_logits
922
923
924
925

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one generation step for ``batch`` and update the batch in place.

        Steps:
          1. allocate KV-cache blocks if ``batch.needed_blocks_slots`` is set;
          2. run :meth:`forward` and pick the next (and speculative) tokens
             with ``batch.next_token_chooser``;
          3. first pass over the requests: tensor bookkeeping only (position
             ids after prefill, prefill-logprob gather indices,
             ``all_input_ids_tensor``);
          4. explicit GPU -> CPU transfers via ``.tolist()``;
          5. second pass: detokenize, apply stopping criteria and build a
             :class:`Generation` for each request with
             ``i % world_size == rank``.

        Returns:
            ``(generations, next_batch, (forward_ns, decode_ns))``.
            ``next_batch`` is ``None`` when every request stopped in this step.
            ``forward_ns`` runs up to the end of step 4; ``decode_ns`` covers
            step 5.
        """
        forward_start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            (
                batch.block_tables,
                batch.block_tables_tensor,
                batch.slots,
            ) = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None

        try:
            out, speculative_logits = self.forward(batch)
        except Exception as e:
            del batch
            raise e

        next_token_logits = out
        if prefill and prefill_logprobs:
            # `out` has rows for the prompt tokens too; keep only the rows
            # used to choose the next token
            next_token_logits = out[batch.prefill_next_token_indices]
            if speculative_logits is not None:
                speculative_logits = speculative_logits[
                    batch.prefill_next_token_indices
                ]

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # Gather indices for the prefill logprobs. With a single
                # request the slice of batch.input_ids is used directly.
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            # Keep the slot index of each request's last prompt token
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # cu_seqlen_prefill is not needed anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        generations: List[Generation] = []
        stopped = True

        # First pass: tensor bookkeeping. The explicit .tolist() transfers are
        # delayed until after this loop.
        # NOTE(review): `range(n_accepted_ids)` converts an element of the
        # `accepted_ids` tensor to a Python int, which presumably syncs with
        # the device when that tensor is on the GPU; confirm.
        offset = 0
        token_cursor = 0
        first_pass = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
        for i, (input_length, _, n_accepted_ids) in enumerate(first_pass):
            start_index = offset
            end_index = offset + input_length

            if prefill:
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Decode steps only increment position ids; after prefill
                # they restart from the request's last prompt position
                next_position_ids[i] = batch.position_ids[end_index - 1]

                if prefill_logprobs:
                    # Logprob targets are the prompt tokens shifted by one
                    shifted_ids = batch.input_ids[
                        start_index + 1 : start_index + out_length
                    ]
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            shifted_ids
                        )
                    else:
                        prefill_tokens_indices = shifted_ids

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[
                    token_cursor
                ]
                token_cursor += 1

            offset += input_length

        # Each request's new input id is its last accepted token
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # GPU <-> CPU sync
            prefill_logprobs = (
                torch.gather(
                    torch.log_softmax(out, -1),
                    1,
                    prefill_tokens_indices.view(-1, 1),
                )
                .view(-1)
                .tolist()
            )

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        decode_start = time.time_ns()

        # Built before the loop: do_sample / seeds come from the chooser as it
        # was before advance_grammar_single rebinds it below
        second_pass = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        token_cursor = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(second_pass):
            next_token_texts = []
            # Number of accepted tokens kept (everything after a stop is dropped)
            n_kept = n_accepted_ids
            current_stopped = False
            for j in range(token_cursor, token_cursor + n_accepted_ids):
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(next_token_id, next_token_text)
                if stop:
                    n_kept = j - token_cursor + 1
                    current_stopped = True
                    break
                current_stopped = False
            stopped = stopped and current_stopped

            kept_token_ids = next_token_ids[token_cursor : token_cursor + n_kept]
            kept_token_logprobs = next_token_logprobs[
                token_cursor : token_cursor + n_kept
            ]
            token_cursor += n_accepted_ids

            # Shard generations: the rust sharded client appends them all
            if i % self.world_size == self.rank:
                generated_text = None
                if stop:
                    # Decode the whole generated text
                    read_start = len(all_input_ids) - stopping_criteria.current_tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=read_start - 1,
                        read_offset=read_start,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )

                prefill_tokens = None
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Drop the generated token's logprob; the first prompt
                    # token has no logprob, so it gets NaN
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )

                top_tokens = None
                if top_n_tokens > 0:
                    # One Tokens entry per accepted position
                    top_tokens = [
                        Tokens(
                            step_ids,
                            step_logprobs,
                            self.tokenizer.batch_decode(
                                step_ids,
                                clean_up_tokenization_spaces=False,
                                skip_special_tokens=False,
                            ),
                            [tid in self.all_special_ids for tid in step_ids],
                        )
                        for step_ids, step_logprobs in zip(
                            top_token_ids, top_token_logprobs
                        )
                    ]

                generations.append(
                    Generation(
                        request.id,
                        prefill_tokens,
                        Tokens(
                            kept_token_ids,
                            kept_token_logprobs,
                            next_token_texts,
                            [nid in self.all_special_ids for nid in kept_token_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )
                )

            # Advance the grammar once per kept token: with speculative
            # decoding a request can gain several tokens in one step
            for next_token_id in kept_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update per-request values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # Every request stopped: there is no batch to hand back
            del batch
            next_batch = None
        else:
            # Prefill-only metadata is not needed by the following decode steps
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None
            next_batch = batch

        forward_ns = decode_start - forward_start
        decode_ns = time.time_ns() - decode_start
        return generations, next_batch, (forward_ns, decode_ns)