import math
import os
import time
import itertools

import torch
import torch.distributed

import numpy as np

from loguru import logger

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)


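# Batch container for flash/paged-attention models: it carries the flattened
# input/position ids, the paged-attention block tables and slots, and the
# per-request generation state (samplers, stopping criteria, top-n tokens).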
@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]

    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: torch.Tensor

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
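        # Tokenize every request's input text in one batched tokenizer call,
        # truncating to the largest per-request `truncate` value in the batch.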
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        position_ids = []
        speculative_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # NOTE: de-duplication of a doubled BOS token is disabled here because it
            # crashes when the tokenized input has length 1 (tokenized_input[1] would
            # be out of range).
            # if (
            #     tokenized_input[0] == tokenizer.bos_token_id
            #     and tokenized_input[1] == tokenizer.bos_token_id
            # ):
            #     tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
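            # Reserve enough KV-cache blocks for the prompt plus every token this
            # request may still generate (including speculative tokens), and record
            # which flat slot index each prompt token will write to.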
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

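            # Track which hidden states we need logits for: every prompt position when
            # the request asked for prefill logprobs, otherwise only the last one.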
            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
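        # Keep only the listed request ids: re-index every CPU list and GPU tensor,
        # free the KV-cache blocks of dropped requests, and rebuild the slot mapping.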
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
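        # Merge several running batches into one: offsets slot indices and request
        # mappings by the cumulative sizes, and copies every per-request tensor into
        # freshly allocated, padded tensors.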
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            # Free blocks
            get_cache_manager().free(
                list(itertools.chain.from_iterable(self.block_tables))
            )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        self.cuda_graphs = {}

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
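        # Capture a CUDA graph for a decode step at batch size `bs`, using static
        # dummy tensors so the same graph can later be replayed with real inputs
        # copied into them.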
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=None,
            )
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def warmup(self, batch: FlashCausalLMBatch):
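        # Run one prefill on the largest possible batch to measure peak memory, then
        # size the KV cache to fill the remaining GPU memory and (optionally) capture
        # CUDA graphs for decode.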
        # The warmup batch is the biggest batch we could ever receive
        torch.cuda.empty_cache()
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        torch.cuda.synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
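        # Worked example (hypothetical model, assuming BLOCK_SIZE is 16): with 32
        # layers, 32 KV heads of head size 128 and fp16, one block costs
        # 32 * (16 * 32 * 128) * 2 * 2 bytes = 8 MiB across all layers.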

        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")

        return int(num_blocks * BLOCK_SIZE)

    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
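        # Run the model on this batch. With speculative ids the batch is expanded so
        # that every candidate token gets its own row; in decode a pre-captured CUDA
        # graph is replayed when one exists for a padded batch size.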
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()
        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
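        # One scheduler step: allocate KV-cache blocks if needed, run the forward
        # pass, sample (and possibly speculate) next tokens, then decode and package
        # per-request generations.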
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception as e:
            del batch
            raise e

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
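        # `accepted_ids.cumsum(dim=-1) - 1` picks, for every request, the index of its
        # last accepted token, which becomes the input id for the next decode step.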
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            left = 0

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)