"test/kaldi_compatibility_impl.py" did not exist on "9ceb96c0c5edbae4b5f5e2701f4bc702af96b087"
flash_causal_lm.py 46.2 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
14
from typing import Optional, Tuple, List, Type, Dict
15

OlivierDehaene's avatar
OlivierDehaene committed
16
from text_generation_server.models import Model
17
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
18
from text_generation_server.utils.speculate import get_speculate
19
20
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
21
    Tokens,
22
23
24
    Generation,
    GeneratedText,
)
25
26
27
28
29
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
30
from text_generation_server.pb import generate_pb2
31
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
32
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
33
from text_generation_server.utils.dist import MEMORY_FRACTION
34
35

tracer = trace.get_tracer(__name__)
36
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
37

38
39
40
41
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state used by the flash / paged attention causal LM.

    Fields are grouped by purpose (decoder inputs, flash attention, paged
    attention, prefill metadata, token bookkeeping, generation helpers).
    Fields marked ``Optional`` are set to ``None`` by at least one of the
    constructors (``from_pb``, ``filter``, ``concatenate``).
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # None when the batch carries no speculative ids (e.g. right after `from_pb`)
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int
100

101
102
    def to_pb(self) -> generate_pb2.CachedBatch:
        """Serialize this batch into a ``generate_pb2.CachedBatch`` message.

        ``max_tokens`` is reported as the number of blocks held by the batch
        multiplied by ``BLOCK_SIZE``.
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize the inputs of all ``requests`` in a single tokenizer call.

        Args:
            requests: iterable of requests exposing ``inputs`` and ``truncate``.
            tokenizer: callable tokenizer; it is invoked with
                ``truncation=True`` and ``max_length`` set to the largest
                ``truncate`` value found in ``requests`` (0 if there are none).

        Returns:
            The ``"input_ids"`` entry of the tokenizer output.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        # Per-request truncation is not applied here; the tokenizer only gets
        # the batch-wide maximum.
        return tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
121

122
123
124
125
126
127
128
129
130
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from a protobuf ``Batch``.

        Tokenizes every request, then computes per-request position ids,
        cumulative sequence lengths, slot indices, block requirements and the
        prefill logprob index helpers, and finally materializes the tensors on
        ``device``.

        Args:
            pb: protobuf batch whose ``requests`` are parsed.
            tokenizer: tokenizer used for the inputs and the stopping criteria.
            dtype: forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: device on which the batch tensors are created.

        Returns:
            A new batch with ``block_tables``, ``block_tables_tensor``,
            ``slots`` and ``speculative_ids`` left as ``None``.
        """
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            # Keep only the last `r.truncate` tokens of this request
            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length guard avoids an
            # IndexError on prompts that tokenize to fewer than two tokens.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            # NOTE(review): the 5-token look-back is presumably the window used
            # for incremental detokenization - confirm against the decoder.
            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Keep every prompt position for this request
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Keep only the last prompt position for this request
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # `torch.cat` already yields a fresh tensor: convert it instead of
            # copy-constructing through `torch.tensor`, which emits a warning.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

332
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch restricted to ``request_ids``.

        Blocks belonging to requests that are dropped are released through the
        cache manager, and ``self.block_tables`` is set to ``None`` so the
        remaining blocks are not freed when this instance is finalized.

        Args:
            request_ids: ids of the requests to keep, in the desired order.

        Returns:
            ``self`` when the number of ids equals the batch size, otherwise a
            new batch of the same type.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Number of slots this request still occupies in the new batch
            request_slot_span = request_input_length + remaining_tokens - 1

            # Set slice
            old_slot_start = self.start_slots[idx]
            slot_filtering_indices[
                old_slot_start : old_slot_start + request_slot_span
            ] = True

            cumulative_max_length += request_slot_span

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into a single one.

        Output tensors are preallocated from the first batch's tensors and
        filled batch by batch; slot indices and start slots are shifted by the
        number of slots of the preceding batches. The ``block_tables`` of every
        input batch is set to ``None`` so their blocks are not freed when the
        inputs are finalized.

        Args:
            batches: batches to merge; the first one provides the batch id and
                the next-token-chooser dtype, device and tokenizer.

        Returns:
            A new batch containing the requests of all ``batches`` in order.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        first = batches[0]
        input_ids = first.input_ids.new_empty(total_batch_size)
        position_ids = first.position_ids.new_empty(total_batch_size)
        slots = first.slots.new_empty(total_slots)
        slot_indices = first.slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = first.input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = first.block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = first.all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = first.top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset the mapping of each batch by the cumulative batch size.
            # The offset is 0 for the first batch, so one loop covers every
            # batch, and writing into a fresh dict leaves the inputs' own
            # mappings untouched.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=first.next_token_chooser.dtype,
            device=first.next_token_chooser.device,
            tokenizer=first.next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if first.speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=first.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

653
654
655
    def __del__(self):
        """Release the blocks still referenced by this batch, if any."""
        # `getattr` keeps the finalizer quiet on an instance whose construction
        # failed before `block_tables` was assigned.
        block_tables = getattr(self, "block_tables", None)
        if block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(block_tables))
            )
659

660
661
662
663
664
665
666
    def __len__(self):
        """Return the number of requests held by this batch."""
        request_count = len(self.requests)
        return request_count


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Store the cache geometry and initialise the base `Model`.

        `num_layers`, `num_kv_heads` and `head_size` are kept on the instance;
        `warmup` passes them to `set_cache_manager`. All remaining arguments
        are forwarded to the parent constructor, with `requires_padding`
        fixed to `False`.
        """
        # Cache geometry, read back by `warmup`
        self.num_layers, self.num_kv_heads, self.head_size = (
            num_layers,
            num_kv_heads,
            head_size,
        )

        # Captured decode graphs, keyed by batch size (see `cuda_graph_warmup`)
        self.cuda_graphs = {}

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class associated with this model: `FlashCausalLMBatch`."""
        return FlashCausalLMBatch

699
700
701
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a CUDA graph of one decode forward pass for batch size `bs`.

        Static input tensors are allocated for `bs` rows and `max_bt` block
        table columns, the model is run once eagerly as a warmup, and the same
        call is then recorded into a `torch.cuda.CUDAGraph` using `MEM_POOL`.
        The static inputs, the graph and the captured outputs are stored in
        `self.cuda_graphs[bs]`.

        The entry is registered only after the capture has succeeded. `warmup`
        logs and swallows `torch.cuda.OutOfMemoryError` raised from here, and
        `forward` replays whichever entry it finds in `self.cuda_graphs`, so a
        half-built entry (no captured graph, no "logits") must never be left
        behind.

        Args:
            bs: batch size the graph is captured for.
            max_s: value passed to the model as `max_s` and used to fill the
                static `input_lengths` tensor.
            max_bt: number of columns of the static block table.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        graph = torch.cuda.CUDAGraph()
        # Built locally; published to `self.cuda_graphs` only once complete
        entry = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
            "graph": graph,
        }

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=None,
            )
            entry["logits"] = logits
            entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

        # Capture succeeded: make the graph available to `forward`
        self.cuda_graphs[bs] = entry

753
    def warmup(self, batch: FlashCausalLMBatch):
        """Size the KV cache from a warmup batch and optionally capture CUDA graphs.

        Steps:
          1. Install a cache manager with `batch.blocks` blocks and run
             `generate_token` once on `batch`.
          2. Derive how many additional blocks fit in the remaining device
             memory, then install a new cache manager with that many blocks
             plus the ones allocated in step 1.
          3. If `CUDA_GRAPHS` is set, capture one decode graph per listed
             batch size (sizes smaller than `self.speculate + 1` are skipped).

        Returns:
            `num_blocks * BLOCK_SIZE` as an int.

        Raises:
            RuntimeError: if step 1 raises `torch.cuda.OutOfMemoryError`.
            NotImplementedError: if the system is neither CUDA, ROCm nor XPU.
        """
        on_cuda_or_rocm = IS_CUDA_SYSTEM or IS_ROCM_SYSTEM

        # Arguments shared by both `set_cache_manager` calls, after the block count
        cache_args = (
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        # The warmup batch is the biggest batch we could ever receive
        if on_cuda_or_rocm:
            torch.cuda.empty_cache()
        elif IS_XPU_SYSTEM:
            torch.xpu.empty_cache()

        try:
            cache_manager = set_cache_manager(batch.blocks, *cache_args)
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                "You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        if on_cuda_or_rocm:
            torch.cuda.synchronize(self.device)
        elif IS_XPU_SYSTEM:
            torch.xpu.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Bytes taken by one block across all layers (factor 2 as in the original formula)
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        if on_cuda_or_rocm:
            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
            total_gpu_memory = torch.cuda.get_device_properties(
                self.device
            ).total_memory
            # Keep (1 - MEMORY_FRACTION) of the device memory out of the budget
            reserved_memory = (1 - MEMORY_FRACTION) * total_gpu_memory
            free_memory = max(0, total_free_memory - reserved_memory)
        elif IS_XPU_SYSTEM:
            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
            free_memory = int(total_gpu_memory * 0.5)
        else:
            raise NotImplementedError("FlashModel is only available on GPU")

        # Leave 5% for some wiggle room
        extra_blocks = int((free_memory * 0.95) // total_cache_size)
        # Add batch.blocks as we allocated it above, so it is included in the peak memory.
        num_blocks = extra_blocks + cache_manager.num_blocks

        # Drop the local references to the warmup batch and the first cache
        # manager before installing the final one
        del batch
        del cache_manager

        set_cache_manager(num_blocks, *cache_args)

        if not CUDA_GRAPHS:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
        else:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is not None and self.speculate + 1 > bs:
                        continue
                    self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")

        return int(num_blocks * BLOCK_SIZE)
835

836
837
838
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model forward pass for `batch`.

        When `batch.speculative_ids` is set, every row is expanded into
        `speculative_length + 1` consecutive entries (the current token
        followed by its speculative ids) before the model is called.

        The model is called directly when the batch is a prefill
        (`cu_seqlen_prefill` is set) or when no captured CUDA graph is large
        enough. Otherwise the inputs are copied into the static tensors of the
        smallest captured graph whose size is >= the batch size, the graph is
        replayed, and its outputs are sliced back to the batch size.

        Returns:
            `(logits, speculative_logits)`; the second element may be `None`.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        draft_ids = batch.speculative_ids
        if draft_ids is not None:
            n_rows, n_drafts = draft_ids.shape
            width = n_drafts + 1

            steps = torch.arange(width, device=position_ids.device).unsqueeze(0)
            steps_i32 = steps.to(dtype=torch.int32)

            def _spread(values, offsets):
                # One entry per (row, step): row value + step offset, flattened
                return (values.unsqueeze(-1).expand(n_rows, width) + offsets).view(-1)

            input_ids = torch.cat([input_ids.unsqueeze(-1), draft_ids], dim=1).reshape(
                -1
            )
            position_ids = _spread(position_ids, steps)
            slots = _spread(slots, steps_i32)
            input_lengths = _spread(input_lengths, steps_i32)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(n_rows, width, -1)
                .reshape(n_rows * width, -1)
                .contiguous()
            )
            max_s = max_s + n_drafts

        bs = input_ids.shape[0]

        # Smallest captured graph that can hold `bs` rows, if any
        padded_bs = min((size for size in self.cuda_graphs if size >= bs), default=None)
        cuda_graph = None if padded_bs is None else self.cuda_graphs[padded_bs]

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        n_blocks_rows, n_blocks_cols = block_tables.shape[0], block_tables.shape[1]
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][:n_blocks_rows, :n_blocks_cols] = block_tables
        # Rows beyond the real batch keep slot -1 and input length 0
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        captured_speculative = cuda_graph["speculative_logits"]
        if captured_speculative is not None:
            speculative_logits = captured_speculative[:bs]
        else:
            speculative_logits = None
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
933
934
935
936

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one generation step on `batch` and build the per-request results.

        The step allocates cache blocks if the batch still needs them, runs
        `self.forward`, lets `batch.next_token_chooser` pick the next token(s)
        (several per request when speculative ids are accepted), updates the
        batch in place, and then decodes tokens and evaluates the stopping
        criteria on the CPU.

        Returns:
            A tuple `(generations, next_batch, (forward_ns, decode_ns))`:
            - `generations`: one `Generation` per request whose index `i`
              satisfies `i % self.world_size == self.rank`;
            - `next_batch`: the updated `batch`, or `None` when every request
              stopped during this step;
            - `forward_ns`: nanoseconds from entry until the token ids and
              logprobs have been copied to Python lists;
            - `decode_ns`: nanoseconds spent after that point.

        Raises:
            Any exception raised by `self.forward` is re-raised unchanged.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception:
            # Drop the local reference to the batch, then propagate the
            # original exception with its traceback untouched
            del batch
            raise

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            # Write every accepted token right after the request's current tokens
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        # The last accepted token of each request becomes its next input id
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded because a stop was hit first
            left = 0
            # Bound up front so that neither `reason` nor the stop flag depends
            # on the loop below having run (or on a previous request's values)
            reason = None

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    # One `Tokens` entry per element of the request's top-token lists
                    all_top_tokens = []
                    for step_token_ids, step_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_token_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                step_token_ids,
                                step_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only metadata is not needed for the following decode steps
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)