flash_causal_lm.py 46.3 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
14
from typing import Optional, Tuple, List, Type, Dict
15

OlivierDehaene's avatar
OlivierDehaene committed
16
from text_generation_server.models import Model
17
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
18
from text_generation_server.utils.speculate import get_speculate
19
20
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
21
    Tokens,
22
23
24
    Generation,
    GeneratedText,
)
25
26
27
28
29
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
30
from text_generation_server.pb import generate_pb2
31
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
32
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
33
from text_generation_server.utils.dist import MEMORY_FRACTION
34
35

# Module-level OpenTelemetry tracer; batch operations below (e.g. `filter`,
# `concatenate`) are wrapped in spans via `tracer.start_as_current_span`.
tracer = trace.get_tracer(__name__)
Nicolas Patry's avatar
Nicolas Patry committed
36
37
38
39
40
from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)
41

42

43
44
45
46
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch of requests for a flash/paged-attention causal language model.

    Pairs per-request host-side bookkeeping (token id lists, detokenization
    offsets, stopping criteria) with the tensors consumed by the model
    forward pass. Paged-attention KV-cache blocks are tracked in
    ``block_tables``: blocks of dropped requests are freed in ``filter``, and
    all remaining blocks are freed in ``__del__``. ``filter`` and
    ``concatenate`` set ``block_tables = None`` on the batches they consume so
    that a consumed batch's ``__del__`` does not free blocks now owned by the
    new batch.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # None when no speculative candidates are present (from_pb passes None)
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b + 1 containing the cumulative sequence lengths of the
    # sequences in the batch (starting at 0), only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Serialize the batch summary (id, request ids, size, capacity) to protobuf."""
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            # Capacity is reported in tokens: each block holds BLOCK_SIZE slots.
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize the inputs of all ``requests`` in a single tokenizer call.

        Truncation uses the largest ``truncate`` value among the requests;
        ``from_pb`` then applies each request's own ``truncate`` limit.

        Returns:
            The tokenizer output's ``"input_ids"`` entry, one token id
            sequence per request.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from a protobuf ``Batch``.

        Tokenizes every request and lays the prompts out back to back (no
        padding) for the prefill pass. Records how many KV-cache blocks and
        slots each request will need, and precomputes the indices used to
        extract prefill logprobs. KV-cache blocks are not allocated here:
        ``block_tables``, ``block_tables_tensor`` and ``slots`` are ``None``.

        Args:
            pb: the batch of generation requests.
            tokenizer: tokenizer for the prompts and the stopping criteria.
            dtype: dtype forwarded to the next-token chooser.
            device: device on which the batch tensors are created.
        """
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Extra slots reserved per request for speculative tokens (loop-invariant).
        speculative_length = get_speculate()

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length check prevents
            # the IndexError the unguarded version raised on single-token
            # inputs, which is why this check had been disabled.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Keep logits for every prompt position of this request
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Keep only the logits of the last prompt position
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # torch.cat already yields a tensor; convert it with .to() rather
            # than re-wrapping it in torch.tensor(), which PyTorch warns against.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a decode batch restricted to ``request_ids``, in that order.

        KV-cache blocks of requests that are not kept are freed through the
        cache manager. Ownership of the kept blocks moves to the returned
        batch, and ``self.block_tables`` is cleared.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU): slot of this request's last input token
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice: keep the slots this request still needs
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        # Free the blocks of every request that is not part of the new batch
        block_indices_to_free = []
        for i, r in enumerate(self.requests):
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several decode batches into one, keeping request order.

        Device tensors are copied into freshly allocated tensors sized for the
        combined batch. Per-batch offsets are added to slot indices, start
        slots and the request-id mapping. The input batches give up ownership
        of their KV-cache blocks (``block_tables`` is set to ``None``) to the
        returned batch.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the number of requests before it.
            # A fresh dict is filled (rather than reusing batches[0]'s mapping
            # as the accumulator) so no input batch is mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        """Return all KV-cache blocks still owned by this batch to the cache manager."""
        # None (ownership transferred) and [] (nothing allocated) both skip freeing.
        if self.block_tables:
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Wrap a flash-attention causal language model.

        Args:
            model: module whose ``forward`` is called with paged KV-cache
                arguments by ``forward`` and ``cuda_graph_warmup``.
            tokenizer: tokenizer forwarded to ``Model``.
            num_layers: number of layers; used by ``warmup`` to size the cache.
            num_kv_heads: key/value heads per layer; used to size the cache.
            head_size: per-head dimension; used to size the cache.
            dtype: forwarded to ``Model``; also the KV-cache dtype.
            device: forwarded to ``Model``; also the KV-cache device.
            rank: shard rank, forwarded to ``Model``.
            world_size: number of shards, forwarded to ``Model``.
            sliding_window: optional sliding-window size forwarded to
                ``Model``; ``warmup`` passes ``self.sliding_window is not None``
                to the cache manager.
        """
        # KV-cache geometry consumed by ``warmup`` when building the cache.
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        # Captured CUDA graphs keyed by batch size; filled by
        # ``cuda_graph_warmup`` and looked up in ``forward``.
        self.cuda_graphs = {}

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class consumed by this model: ``FlashCausalLMBatch``."""
        return FlashCausalLMBatch

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for a batch of ``bs`` tokens.

        Static input tensors are allocated once and stored, together with the
        captured graph and its output tensors, in ``self.cuda_graphs[bs]``.
        ``forward`` later copies real inputs into these tensors and replays
        the graph.

        Args:
            bs: number of input tokens (rows) the graph is captured for.
            max_s: ``max_s`` passed to the model during warmup and capture.
            max_bt: number of block-table columns per row.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
            "graph": graph,
        }

        # Identical arguments for the eager run and the capture, so the graph
        # is recorded against exactly the static tensors stored above.
        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # Run once outside of capture to warm up.
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(**forward_kwargs)
            # Output tensors are rewritten in place on every replay.
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def warmup(self, batch: FlashCausalLMBatch):
        """Measure peak memory on the warmup batch and size the KV cache.

        Runs one ``generate_token`` step on ``batch`` with a cache of
        ``batch.blocks`` blocks, then derives how many additional blocks fit
        in the remaining device memory and re-creates the cache manager with
        that many blocks. When ``CUDA_GRAPHS`` is set, decode graphs are
        captured afterwards.

        Args:
            batch: the largest batch the server may ever receive.

        Returns:
            The number of tokens the final KV cache can hold
            (``num_blocks * BLOCK_SIZE``).

        Raises:
            RuntimeError: if running the warmup batch itself runs out of memory.
            NotImplementedError: on systems that are neither CUDA, ROCm nor XPU.
        """
        # The warmup batch is the biggest batch we could ever receive
        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            torch.cuda.empty_cache()
        elif IS_XPU_SYSTEM:
            torch.xpu.empty_cache()
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            torch.cuda.synchronize(self.device)
        elif IS_XPU_SYSTEM:
            torch.xpu.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Bytes per block across all layers; the factor 2 presumably accounts
        # for separate key and value tensors.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
            total_gpu_memory = torch.cuda.get_device_properties(
                self.device
            ).total_memory

            # Keep (1 - MEMORY_FRACTION) of the device's total memory unused.
            free_memory = max(
                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
            )
        elif IS_XPU_SYSTEM:
            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
            # No free-memory query here: half of the device is assumed free.
            free_memory = int(total_gpu_memory * 0.5)
        else:
            raise NotImplementedError("FlashModel is only available on GPU")

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        # Drop the warmup batch and its cache manager before the cache is
        # re-created with the final size.
        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    # In speculative decode, ``forward`` feeds speculate + 1
                    # tokens per request, so smaller graph sizes are skipped.
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        return int(num_blocks * BLOCK_SIZE)

    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch``, replaying a captured CUDA graph if possible.

        With ``batch.speculative_ids`` set, every sequence is expanded into
        ``1 + speculative_length`` rows (its current token followed by its
        speculative tokens), each with consecutive positions, slots and
        lengths and a copy of the sequence's block table.

        Prefill steps, and decode steps larger than every captured graph, run
        the model eagerly. Otherwise inputs are copied into the smallest
        captured graph that fits and the graph is replayed.

        Returns:
            ``(logits, speculative_logits)``; ``speculative_logits`` may be
            ``None``.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            # Row layout: [t0, s0_1, .., s0_k, t1, s1_1, .., s1_k, ...]
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)

            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        bs = input_ids.shape[0]
        # Smallest captured graph size that can hold ``bs`` rows, if any.
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padded rows get slot -1 and input length 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()
        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one step on ``batch`` and build the resulting generations.

        Allocates KV-cache blocks if the batch still needs them, runs
        ``forward``, lets ``batch.next_token_chooser`` pick the next token(s)
        (more than one per request when speculative tokens are accepted),
        updates ``batch`` in place, and builds a ``Generation`` for every
        request whose index maps to this shard's rank.

        Returns:
            ``(generations, batch, (forward_ns, decode_ns))``. ``batch`` is
            ``None`` when every request in it has stopped.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception:
            del batch
            # Bare ``raise`` keeps the original traceback intact.
            raise

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            # Keep only the slot index of each sequence's last prompt token.
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            # Record the accepted token(s) in the token-id tensor.
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded after a stop token.
            left = 0

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # Distinct loop names so the per-request values are not shadowed.
                    for step_token_ids, step_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_token_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                step_token_ids,
                                step_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only metadata is not needed for subsequent decode steps.
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)