"""Batch state and model wrapper for causal LMs run with flash / paged attention."""

import itertools
import math
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import numpy as np
import torch
import torch.distributed
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase

# Project-local imports are kept in their original relative order.
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)

# OpenTelemetry tracer used by the ``@tracer.start_as_current_span`` decorators below.
tracer = trace.get_tracer(__name__)

@dataclass
class FlashCausalLMBatch(Batch):
    """Per-batch state for flash / paged-attention causal LM generation.

    Holds per-request bookkeeping (token ids, lengths, detokenization offsets,
    stopping criteria) together with the tensors fed to the model forward pass.

    Lifecycle:
    - Built from a protobuf batch via ``from_pb`` / ``from_tokenized``.
    - Shrunk with ``filter`` or merged with ``concatenate``.

    Paged-attention blocks listed in ``block_tables`` are returned to the cache
    manager in ``__del__``. ``filter`` and ``concatenate`` set ``block_tables``
    to ``None`` on the source batches, so blocks handed over to the new batch
    are not freed twice.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize the batch as a ``CachedBatch`` message.

        The message carries the id, request ids, size and token capacity
        (``blocks * BLOCK_SIZE``).
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize every request's ``inputs`` in a single tokenizer call.

        All inputs are truncated to the largest ``truncate`` value in the
        batch. Each request's own ``truncate`` is applied afterwards in
        ``from_tokenized``.

        Returns:
            The tokenizer's ``input_ids``: one token-id list per request.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize ``pb``'s requests and build a prefill batch from them."""
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from already-tokenized request inputs.

        For each request this:
        - keeps the last ``r.truncate`` tokens;
        - drops a duplicated leading BOS token;
        - reserves ``input_length + max_new_tokens - 1 + speculate`` KV-cache
          slots;
        - records the indices needed to gather prefill logprobs.

        Block tables and slots are left as ``None`` here.

        Args:
            pb: protobuf batch whose ``requests`` align with ``batch_tokenized_inputs``.
            tokenizer: used for the BOS id, stopping criteria and token chooser.
            batch_tokenized_inputs: one token-id list per request.
            dtype: dtype passed to the next-token chooser.
            device: device on which the batch tensors are created.
        """
        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Extra slots reserved per request for speculative tokens. get_speculate()
        # takes no per-request input, so it is read once instead of per request.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length check keeps inputs
            # with fewer than two tokens (e.g. just [bos]) from raising IndexError.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # torch.cat already returns a tensor; convert it with .to() instead of
            # copy-constructing via torch.tensor(), which PyTorch warns against.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch containing only ``request_ids``, in that order.

        Blocks owned by dropped requests are freed through the cache manager.
        ``self.block_tables`` is then set to ``None``, so this batch's
        ``__del__`` does not free the blocks the new batch now holds.

        Raises:
            ValueError: if ``request_ids`` is empty.

        If ``len(request_ids) == len(self)``, ``self`` is returned unchanged.
        The ids are assumed to be the same requests.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge ``batches`` (in order) into a single batch.

        The result carries no prefill metadata: ``cu_seqlen_prefill`` and the
        ``prefill_*`` fields are ``None``. Every input batch has its
        ``block_tables`` set to ``None``, so its ``__del__`` does not free
        blocks now owned by the result.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the number of requests before it.
            # Building a fresh dict (instead of reusing the first batch's
            # mapping object) avoids mutating an input batch.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        """Return this batch's paged-attention blocks to the cache manager.

        Does nothing if ``block_tables`` is ``None`` or empty.
        """
        if self.block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Wrap a flash-attention causal LM and record its KV-cache geometry.

        Args:
            model: the underlying model; its ``forward`` is called by this class.
            tokenizer: tokenizer forwarded to ``Model``.
            num_layers: number of layers, used when sizing the KV cache.
            num_kv_heads: number of key/value heads, used when sizing the KV cache.
            head_size: per-head dimension, used when sizing the KV cache.
            dtype: model/cache dtype.
            device: device the model and cache live on.
            rank: shard rank of this process.
            world_size: total number of shards.
            sliding_window: optional sliding-window length forwarded to ``Model``.
        """
        # Cache geometry consumed by `warmup` when (re)building the cache manager.
        self.num_layers, self.num_kv_heads, self.head_size = (
            num_layers,
            num_kv_heads,
            head_size,
        )

        # Captured decode CUDA graphs keyed by batch size (see `cuda_graph_warmup`).
        self.cuda_graphs = {}

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class this model operates on."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

716
717
718
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for a batch of ``bs`` rows.

        Static input buffers are allocated here; ``forward`` later copies the real
        (possibly smaller) decode inputs into them and replays the graph.

        Args:
            bs: batch size (number of rows) the graph is captured for.
            max_s: ``max_s`` passed to the model while capturing; every dummy row
                is also given this input length.
            max_bt: number of block-table columns per row.

        The entry is published in ``self.cuda_graphs`` only once capture has
        completed. If the warmup pass or the capture raises (for example
        ``torch.cuda.OutOfMemoryError``), no half-initialised entry is left
        behind for ``forward`` to select later.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        # Build the entry locally; it is registered only after a successful capture.
        cuda_graph = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        cuda_graph["graph"] = graph

        # Identical arguments for the warmup pass and the captured pass.
        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # Run once outside the graph as a warmup pass
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(**forward_kwargs)
            cuda_graph["logits"] = logits
            cuda_graph["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

        self.cuda_graphs[bs] = cuda_graph

770
    def warmup(self, batch: FlashCausalLMBatch):
        """Size the KV cache from the memory left after a worst-case forward pass.

        ``batch`` is expected to be the biggest batch that could ever be received.
        The method proceeds as follows:

        1. Run ``batch`` once through ``generate_token`` on a temporary cache of
           ``batch.blocks`` blocks.
        2. Convert the memory still free afterwards (minus a 5% margin) into
           additional cache blocks.
        3. Rebuild the cache manager with the final block count.
        4. When ``CUDA_GRAPHS`` is non-empty, capture a decode CUDA graph for each
           listed size. If ``self.speculate`` is set, sizes below
           ``self.speculate + 1`` are skipped.

        Returns:
            Number of tokens the final cache can hold (``num_blocks * BLOCK_SIZE``).

        Raises:
            RuntimeError: if the warmup forward pass runs out of CUDA memory.
        """
        # Positional arguments shared by both `set_cache_manager` calls
        # (everything after the block count).
        cache_layout = (
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        empty_cache()

        try:
            cache_manager = set_cache_manager(batch.blocks, *cache_layout)
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Bytes taken by one cache block across all layers; the factor 2 presumably
        # covers separate key and value caches.
        element_bytes = torch.tensor([], dtype=self.dtype).element_size()
        bytes_per_block = (
            2
            * self.num_layers
            * BLOCK_SIZE
            * self.num_kv_heads
            * self.head_size
            * element_bytes
        )

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
        # Keep 5% of the free memory as wiggle room, and add back the blocks that
        # were allocated for the warmup batch (they are part of the measured peak).
        num_blocks = (
            int(free_memory * 0.95 // bytes_per_block) + cache_manager.num_blocks
        )

        # Drop the local references to the warmup batch and the temporary cache
        # manager before the final cache is allocated.
        del batch
        del cache_manager

        set_cache_manager(num_blocks, *cache_layout)

        if not CUDA_GRAPHS:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
            return int(num_blocks * BLOCK_SIZE)

        try:
            logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
            # Warmup cuda graphs
            for bs in CUDA_GRAPHS:
                if self.speculate is not None and bs < self.speculate + 1:
                    continue
                self.cuda_graph_warmup(bs, max_s, max_bt)
        except torch.cuda.OutOfMemoryError:
            logger.exception(f"Decode cuda graph warmup failed")

        return int(num_blocks * BLOCK_SIZE)
836

837
838
839
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch`` and return ``(logits, speculative_logits)``.

        When ``batch.speculative_ids`` is set, each sequence is expanded into
        ``1 + speculative_length`` rows. These are the current token followed by
        its speculative ids, at consecutive positions and slots, with the input
        length incremented row by row.

        Prefill steps, or decode steps with no captured CUDA graph of size
        ``>=`` the row count, call ``self.model.forward`` eagerly. Other decode
        steps copy their inputs into the smallest such graph's static buffers,
        replay it, and slice the padded outputs back to the real row count.
        """
        # Inputs common to the plain and speculative paths.
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1

            # Per-row offsets 0..speculative_length, as int64 and int32 variants.
            offsets = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            offsets_int = offsets.to(dtype=torch.int32)

            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + offsets
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + offsets_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + offsets_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        bs = input_ids.shape[0]
        # Smallest captured graph able to hold `bs` rows, if any.
        padded_bs = min((size for size in self.cuda_graphs if size >= bs), default=None)
        cuda_graph = None if padded_bs is None else self.cuda_graphs[padded_bs]

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded: padded slots are set to -1 and
        # padded input lengths to 0.
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        static_speculative = cuda_graph["speculative_logits"]
        speculative_logits = (
            None if static_speculative is None else static_speculative[:bs]
        )
        return cuda_graph["logits"][:bs], speculative_logits
934
935
936
937

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one forward step on ``batch`` and turn its output into generations.

        The step proceeds as follows:

        - Allocate KV-cache blocks first if ``batch.needed_blocks_slots`` is set.
        - Run ``self.forward``.
        - Choose the next token(s) with ``batch.next_token_chooser``. With
          speculation, several tokens may be accepted per request.
        - Update the batch in place.
        - Build one ``Generation`` per request owned by this shard
          (``i % world_size == rank``).

        Returns:
            ``(generations, batch, (forward_ns, decode_ns))``. ``batch`` is
            ``None`` when every request has stopped. ``forward_ns`` runs from the
            start of the call until after the GPU -> CPU sync. ``decode_ns``
            covers the per-request detokenization / stopping loop.

        Raises:
            Any exception raised by ``self.forward`` is propagated unchanged.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        # Boolean here; reassigned below to the list of gathered prefill logprobs
        # (prefill) or None (decode).
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception:
            # Drop this frame's reference to the batch; bare `raise` keeps the
            # original traceback intact.
            del batch
            raise

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, n_accepted_ids) in enumerate(
            zip(batch.input_lengths, accepted_ids)
        ):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            # Write every accepted token into the padded token tensor.
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        # Last accepted token of each request becomes its next input id.
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded after a stopping token.
            left = 0

            # Lazy formatting: the message is only built when DEBUG is enabled.
            logger.debug("Accepted ids {}", n_accepted_ids)

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                # `current_stopped` (not the inner-loop `stop`) reflects *this*
                # request even if its acceptance loop did not run.
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # One entry per accepted position; distinct names avoid
                    # shadowing the per-request lists being iterated.
                    for token_ids, token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids for token_id in token_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                token_ids,
                                token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)