# flash_causal_lm.py (46.2 KB)
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
14
from typing import Optional, Tuple, List, Type, Dict
15

# blame: OlivierDehaene
16
from text_generation_server.models import Model
17
from text_generation_server.utils.tokens import batch_top_tokens
# blame: Nicolas Patry
18
from text_generation_server.utils.speculate import get_speculate
19
20
from text_generation_server.models.types import (
    Batch,
    # blame: Nicolas Patry
21
    Tokens,
22
23
24
    Generation,
    GeneratedText,
)
25
26
27
28
29
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
30
from text_generation_server.pb import generate_pb2
31
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
32
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
33
from text_generation_server.utils.dist import MEMORY_FRACTION
34
35

# Module-level OpenTelemetry tracer; its start_as_current_span decorators wrap
# FlashCausalLMBatch.filter and FlashCausalLMBatch.concatenate in this file.
tracer = trace.get_tracer(__name__)
# blame: Nicolas Patry
36
37
38
39
40
41
from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

42

43
44
45
46
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state for flash/paged-attention causal LM inference.

    A prefill batch is built by :meth:`from_pb`. :meth:`filter` and
    :meth:`concatenate` produce decode batches, in which ``cu_seqlen_prefill``
    and the ``prefill_*`` metadata are ``None``.

    The batch owns the KV-cache block ids listed in ``block_tables``.
    ``__del__`` returns them to the cache manager. ``filter`` and
    ``concatenate`` hand ownership to the new batch by setting the old
    batch's ``block_tables`` to ``None``.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # None when no speculative ids have been attached (always None from from_pb)
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize this batch as a ``CachedBatch`` protobuf.

        ``max_tokens`` reports the token capacity of the blocks held by the
        batch (``blocks * BLOCK_SIZE``).
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize every request's ``inputs`` in a single tokenizer call.

        The call truncates to the largest ``truncate`` value among the
        requests. The tighter per-request truncation is applied afterwards in
        :meth:`from_pb`.

        Returns:
            The tokenizer's ``"input_ids"`` output, one token-id sequence per
            request, in request order.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from a ``generate_pb2.Batch``.

        For each request this:

        * tokenizes and truncates the input;
        * computes position ids and ``cu_seqlen_prefill``;
        * computes the slot indices and the number of KV-cache blocks/slots
          it will need (``needed_blocks_slots``).

        Block tables themselves are not allocated here; ``block_tables``,
        ``block_tables_tensor`` and ``slots`` are left as ``None``. It also
        builds the prefill-logprob index metadata and the next-token chooser.

        Args:
            pb: Incoming protobuf batch.
            tokenizer: Tokenizer used for the inputs and stopping criteria.
            dtype: Dtype forwarded to the next-token chooser.
            device: Device on which the batch tensors are created.

        Returns:
            A new :class:`FlashCausalLMBatch` ready for prefill.
        """
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Same value for every request: look it up once instead of per iteration.
        speculative_length = get_speculate()

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Collapse a doubled leading BOS token. The length guard avoids an
            # IndexError on inputs that tokenize to fewer than two tokens.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Keep every prefill position of this request.
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Keep only the last position (needed to pick the next token).
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(ids)] = ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # prefill_head_indices is a list of CPU tensors; convert the
            # concatenation directly instead of copy-constructing it through
            # torch.tensor(), which warns and makes an extra copy.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch that keeps only the requests in ``request_ids``.

        The KV-cache blocks of dropped requests are freed through the cache
        manager. Ownership of the kept blocks moves to the returned batch:
        ``self.block_tables`` is set to ``None``, so ``self.__del__`` will not
        free them.

        If ``len(request_ids) == len(self)``, ``self`` is returned unchanged.
        The ids are assumed to be the same requests in that case.

        Raises:
            ValueError: If ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into one decode batch.

        Requests are laid out in the order of ``batches``. Per-batch tensors
        are copied into freshly allocated tensors sized for the combined
        batch. Slot indices and start slots are offset by the number of slots
        of the preceding batches.

        The new batch takes ownership of all KV-cache blocks: each input
        batch's ``block_tables`` is set to ``None`` so its ``__del__`` does
        not free them.

        Returns:
            A new batch with ``batch_id`` taken from ``batches[0]``.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the number of requests before it
            # (0 for the first batch). A fresh dict is built so the first
            # input batch's mapping is not aliased and mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            # A batch's tensors may be wider than the merged tensors
            # (e.g. after filter); copy only the overlapping columns.
            ids_width = min(batch.all_input_ids_tensor.shape[1], max_length)
            all_input_ids_tensor[
                start_index:end_index, :ids_width
            ] = batch.all_input_ids_tensor[:, :ids_width]

            bt_width = min(batch.block_tables_tensor.shape[1], max_blocks)
            block_tables_tensor[
                start_index:end_index, :bt_width
            ] = batch.block_tables_tensor[:, :bt_width]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        """Return this batch's KV-cache blocks to the cache manager, if it still owns any."""
        # getattr: __del__ can run on an instance whose __init__ never
        # completed, in which case block_tables was never assigned.
        block_tables = getattr(self, "block_tables", None)
        if block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(block_tables))
            )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Wrap a flash-attention causal LM.

        Args:
            model: the underlying module; its ``forward`` is called by
                ``forward`` / ``cuda_graph_warmup``.
            tokenizer: tokenizer forwarded to ``Model``.
            num_layers: number of layers; used to size the KV cache.
            num_kv_heads: number of key/value heads; used to size the KV cache.
            head_size: per-head dimension; used to size the KV cache.
            dtype: model/cache dtype.
            device: device the model and cache live on.
            rank: shard rank of this process.
            world_size: number of shards.
            sliding_window: optional sliding-window size, forwarded to ``Model``.
        """
        # KV-cache geometry, read by ``warmup`` when building the cache manager.
        self.num_layers, self.num_kv_heads, self.head_size = (
            num_layers,
            num_kv_heads,
            head_size,
        )

        # Captured decode CUDA graphs, keyed by the batch size they were built for.
        self.cuda_graphs = {}

        # Flash batches are packed, so no padding is requested from the base class.
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class this model operates on."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

704
705
706
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for ``bs`` query tokens.

        Static input tensors are allocated here and stored next to the graph.
        ``forward`` later copies real (possibly fewer) inputs into them, replays
        the graph and slices the captured ``logits`` / ``speculative_logits``.

        Args:
            bs: number of query tokens (rows) the graph is captured for.
            max_s: ``max_s`` passed to the model for warmup and capture.
            max_bt: number of block-table columns per row.

        The entry is published to ``self.cuda_graphs`` only after capture has
        succeeded. A failed capture (e.g. ``torch.cuda.OutOfMemoryError``, which
        ``warmup`` catches and logs) therefore never leaves a half-built entry
        without ``"logits"`` that ``forward`` could select and replay.
        """
        device = self.device
        input_ids = torch.zeros(bs, dtype=torch.int64, device=device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=device)
        slots = torch.arange(bs, dtype=torch.int64, device=device)
        # Dummy inputs: every row pretends to be max_s tokens long and to use
        # block ids 0..max_bt-1.
        input_lengths = torch.ones(bs, dtype=torch.int32, device=device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        # Built locally; only registered in self.cuda_graphs once capture succeeds.
        entry = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        entry["graph"] = graph

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=None,
            )
            # Output tensors of the captured graph; overwritten on every replay.
            entry["logits"] = logits
            entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

        # Capture succeeded: make the graph visible to forward().
        self.cuda_graphs[bs] = entry

758
    def warmup(self, batch: FlashCausalLMBatch):
        """Size the KV cache from free device memory, then warm up CUDA graphs.

        ``batch`` is run once through ``generate_token`` on a temporary cache
        manager holding ``batch.blocks`` blocks. The memory left afterwards
        decides how many blocks the final cache manager gets.

        Args:
            batch: warmup batch; the biggest batch the server could ever receive.

        Returns:
            ``int(num_blocks * BLOCK_SIZE)`` for the final cache.

        Raises:
            RuntimeError: the warmup prefill ran out of CUDA memory.
            NotImplementedError: the device is neither CUDA, ROCm nor XPU.
        """
        # The warmup batch is the biggest batch we could ever receive
        cuda_like = IS_CUDA_SYSTEM or IS_ROCM_SYSTEM

        if cuda_like:
            torch.cuda.empty_cache()
        elif IS_XPU_SYSTEM:
            torch.xpu.empty_cache()

        try:
            # Temporary cache manager sized exactly for the warmup batch.
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as err:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from err

        if cuda_like:
            torch.cuda.synchronize(self.device)
        elif IS_XPU_SYSTEM:
            torch.xpu.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Bytes taken by one cache block across all layers, keys and values (x2).
        elem_bytes = torch.tensor([], dtype=self.dtype).element_size()
        per_layer_block = BLOCK_SIZE * self.num_kv_heads * self.head_size
        bytes_per_block = self.num_layers * per_layer_block * 2 * elem_bytes

        if cuda_like:
            free_now, _ = torch.cuda.mem_get_info(self.device)
            device_total = torch.cuda.get_device_properties(self.device).total_memory
            # Keep (1 - MEMORY_FRACTION) of the device untouched.
            free_memory = max(0, free_now - (1 - MEMORY_FRACTION) * device_total)
        elif IS_XPU_SYSTEM:
            device_total = torch.xpu.get_device_properties(self.device).total_memory
            free_memory = int(device_total * 0.5)
        else:
            raise NotImplementedError("FlashModel is only available on GPU")

        # Leave 5% for some wiggle room, and add back the warmup blocks: they were
        # allocated above, so they are already part of the measured peak.
        num_blocks = (
            int((free_memory * 0.95) // bytes_per_block) + cache_manager.num_blocks
        )

        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        if not CUDA_GRAPHS:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
        else:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs; sizes too small to hold one speculative
                # step (speculate + 1 tokens) are skipped.
                for graph_bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= graph_bs:
                        self.cuda_graph_warmup(graph_bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")

        return int(num_blocks * BLOCK_SIZE)
842

843
844
845
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch`` and return ``(logits, speculative_logits)``.

        When ``batch.speculative_ids`` is set, each request is expanded into
        ``1 + speculative_length`` query rows: its current input id followed by
        its speculative ids. The rows get consecutive positions, slots and input
        lengths, and each row gets its own copy of the request's block table.

        Prefill steps, and decode steps with no captured CUDA graph large enough
        for the row count, call ``self.model.forward`` directly. Otherwise the
        inputs are copied into the smallest captured graph whose size is at
        least the row count, the graph is replayed, and its outputs are sliced
        back to the real row count.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            # [current id, spec_1, ..., spec_k] per request, flattened.
            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        bs = input_ids.shape[0]
        # Smallest captured graph that can hold all rows (graphs are padded).
        padded_sizes = [k for k in self.cuda_graphs if k >= bs]
        cuda_graph = self.cuda_graphs[min(padded_sizes)] if padded_sizes else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded: padded rows get slot -1 and
        # input length 0.
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()
        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
940
941
942
943

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Advance ``batch`` by one model step and build the resulting generations.

        Handles prefill steps (``batch.cu_seqlen_prefill`` set) and decode steps.
        With speculative decoding, more than one token may be accepted per
        request in a single step.

        Args:
            batch: the batch to advance. It is mutated in place: input ids,
                position ids, lengths, offsets, slot indices, token history and
                the next-token chooser are all updated.

        Returns:
            ``(generations, batch, (forward_ns, decode_ns))``:
            - ``generations``: one ``Generation`` per request whose index
              satisfies ``i % self.world_size == self.rank``.
            - ``batch``: ``None`` when every request has stopped.
            - ``forward_ns``: nanoseconds from entry until just after the
              GPU -> CPU sync.
            - ``decode_ns``: nanoseconds for the per-request post-processing.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Allocate once: needed_blocks_slots is cleared afterwards, so later
        # calls on the same batch skip this.
        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception as e:
            # NOTE(review): this only drops the local reference before re-raising;
            # a bare `raise` would avoid adding this frame to the traceback.
            del batch
            raise e

        if prefill:
            # When prefill logprobs are requested, `out` has rows for prompt
            # tokens too; keep only the rows used to pick the next token.
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        # next_input_ids / next_token_logprobs are flat over all accepted tokens
        # of all requests; accepted_ids holds the per-request count (consumed
        # below via cumsum and range(n_accepted_ids)).
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            # Keep only the slot index of each request's last prompt token.
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        # Becomes False as soon as one request has not stopped.
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        # `index` walks the flat next_input_ids tensor across requests.
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                # Rows [out_start_index, out_end_index) of `out` belong to request i.
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                # (each prompt position is paired with the id of the following token)
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            # Write the accepted tokens into the on-device token history.
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        # The last accepted token of each request becomes its next input id.
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens after the stopping token; they are dropped.
            left = 0

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            # NOTE(review): `stop` / `reason` used below come from this inner loop,
            # which relies on n_accepted_ids >= 1 for every request.
            stopped = stopped and current_stopped

            # Tokens actually emitted for this request (up to and including a stop).
            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    # (only the last current_tokens ids, with special tokens skipped)
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    # One Tokens entry per emitted position. The loop rebinds
                    # top_token_ids / top_token_logprobs after zip() has already
                    # captured the per-request lists.
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only bookkeeping is not needed for subsequent decode steps.
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)