flash_causal_lm.py 49.8 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
14
from typing import Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
15
16
17

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
18
from text_generation_server.models import Model
19
from text_generation_server.utils.tokens import batch_top_tokens
20
from text_generation_server.utils.dist import RANK
Nicolas Patry's avatar
Nicolas Patry committed
21
from text_generation_server.utils.speculate import get_speculate
22
23
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
24
    Tokens,
25
26
27
    Generation,
    GeneratedText,
)
28
29
30
31
32
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
33
from text_generation_server.pb import generate_pb2
34
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
fxmarty's avatar
fxmarty committed
35
import text_generation_server.models.globals as tgi_globals
36
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
37
from text_generation_server.utils.dist import MEMORY_FRACTION
38

Nicolas Patry's avatar
Nicolas Patry committed
39
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
40
41
42
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
43
44
)

Nicolas Patry's avatar
Nicolas Patry committed
45
46
# Module-level OpenTelemetry tracer; the batch methods below use it as a
# decorator (`@tracer.start_as_current_span(...)`) to record "filter" and
# "concatenate" spans.
tracer = trace.get_tracer(__name__)

47

48
49
50
51
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state for the flash / paged-attention causal LM path.

    Instances are created for prefill by :meth:`from_pb` /
    :meth:`from_tokenized`, shrunk between decode steps by :meth:`filter`,
    and merged by :meth:`concatenate`. The KV-cache blocks listed in
    ``block_tables`` are returned to ``get_cache_manager()`` either by
    :meth:`filter` (for dropped requests) or by :meth:`__del__`.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # None right after from_tokenized; filter/concatenate propagate None as-is
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Return the ``CachedBatch`` protobuf describing this batch.

        ``max_tokens`` is the number of KV-cache token slots covered by the
        batch's blocks (``blocks * BLOCK_SIZE``).
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize the ``inputs`` of all requests in a single tokenizer call.

        Every input is truncated to the largest ``truncate`` value found in
        the batch (never below 0). The stricter per-request truncation is
        applied later in :meth:`from_tokenized`.

        Returns:
            The tokenizer's ``"input_ids"`` output, one entry per request.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize ``pb``'s requests and build a prefill batch from them."""
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from already-tokenized inputs.

        For each request this:
        - keeps only its last ``truncate`` tokens;
        - drops a duplicated leading BOS token;
        - records position ids, slot indices and block requirements.

        The per-request pieces are then flattened into device tensors.
        KV-cache blocks are not allocated here. ``block_tables``,
        ``block_tables_tensor`` and ``slots`` are left as ``None``, and
        ``needed_blocks_slots`` records how many blocks and slots each
        request needs.
        """
        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # The speculation depth does not depend on the request: read it once.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS (presumably the prompt text already
            # contained one on top of the tokenizer's). The length guard avoids
            # an IndexError on single-token or empty inputs.
            if (
                len(tokenized_input) >= 2
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Select every prompt position of this request
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Select only the last prompt position of this request
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, request_input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(request_input_ids)] = request_input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # torch.cat already returns a tensor: convert it with .to() rather
            # than copy-constructing it through torch.tensor(), which warns.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch containing only the requests in ``request_ids``.

        The returned batch takes over the KV-cache blocks of the kept
        requests. Blocks of dropped requests are freed through the cache
        manager, and ``self.block_tables`` is set to ``None`` so that
        ``__del__`` does not free any block a second time. When
        ``request_ids`` has as many entries as the batch, ``self`` is
        returned unchanged (the ids are assumed to match).

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU): position of this request's last input
            # token inside the new, compacted `slots` tensor
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Slots still needed by this request. The "- 1" mirrors the
            # "first token has no past" convention used in from_tokenized.
            # NOTE(review): unlike from_tokenized, speculative_length is not
            # added here; confirm speculative slots are not needed afterwards.
            request_slots = request_input_length + remaining_tokens - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx] + request_slots
            ] = True

            cumulative_max_length += request_slots

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into a single batch.

        Per-request tensors are copied into newly allocated tensors sized for
        all batches together. ``all_input_ids_tensor`` and
        ``block_tables_tensor`` are zero-padded to the largest width.
        Slot indices and start slots are offset so they point into the
        concatenated ``slots`` tensor.

        The returned batch takes over every KV-cache block. Each input
        batch's ``block_tables`` is set to ``None`` so its ``__del__`` does
        not free those blocks. The result reuses ``batches[0].batch_id``.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the number of requests before it
            # (0 for the first batch). A fresh dict is filled rather than
            # reusing the first batch's mapping, so no input batch is mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        """Return this batch's KV-cache blocks to the cache manager.

        ``filter`` and ``concatenate`` set ``block_tables`` to ``None`` on the
        batches they consume, so blocks handed to a new batch are freed only
        once.
        """
        # getattr: __del__ can run on an instance whose __init__ did not
        # complete, in which case the attribute does not exist.
        block_tables = getattr(self, "block_tables", None)
        if block_tables:
            # Free blocks
            get_cache_manager().free(list(itertools.chain.from_iterable(block_tables)))

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Wrap a flash-attention causal language model.

        Args:
            model: the underlying module whose `forward` is called by this class.
            tokenizer: tokenizer forwarded to the `Model` base class.
            num_layers: number of layers; used by `warmup` to size the KV cache.
            num_kv_heads: number of key/value heads; used to size the KV cache.
            head_size: attention head dimension; used to size the KV cache.
            dtype: model dtype, forwarded to `Model`.
            device: model device, forwarded to `Model`.
            rank: this shard's rank.
            world_size: number of shards.
            sliding_window: optional sliding-window size, forwarded to `Model`.
        """
        # KV-cache geometry read by `warmup` when (re)creating the cache manager.
        self.num_layers, self.num_kv_heads, self.head_size = (
            num_layers,
            num_kv_heads,
            head_size,
        )
        # Captured CUDA graphs, keyed by batch size (filled by `cuda_graph_warmup`).
        self.cuda_graphs = {}

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class this model consumes."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

721
722
723
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for batch size `bs`.

        Static dummy inputs are allocated and stored, together with the graph
        and its output tensors, in `self.cuda_graphs[bs]`. `forward` later copies
        real inputs into these tensors and replays the graph.

        Args:
            bs: batch size (number of rows) to capture.
            max_s: maximum sequence length; every dummy input length is set to it.
            max_bt: number of block-table columns per row.
        """
        device = self.device
        entry = {
            "input_ids": torch.zeros(bs, dtype=torch.int64, device=device),
            "position_ids": torch.zeros(bs, dtype=torch.int32, device=device),
            "kv_cache": get_cache_manager().kv_cache,
            "block_tables": (
                torch.arange(max_bt, dtype=torch.int32, device=device)
                .repeat(bs)
                .reshape((bs, max_bt))
            ),
            "slots": torch.arange(bs, dtype=torch.int64, device=device),
            "input_lengths": torch.ones(bs, dtype=torch.int32, device=device) * max_s,
        }
        self.cuda_graphs[bs] = entry
        graph = entry["graph"] = torch.cuda.CUDAGraph()

        # Same arguments for the eager warm-up pass and the captured pass.
        forward_kwargs = dict(
            input_ids=entry["input_ids"],
            position_ids=entry["position_ids"],
            cu_seqlen_prefill=None,
            kv_cache=entry["kv_cache"],
            block_tables=entry["block_tables"],
            slots=entry["slots"],
            input_lengths=entry["input_lengths"],
            max_s=max_s,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # One eager pass before capturing, to warm up.
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(**forward_kwargs)
            entry["logits"] = logits
            entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

775
    def warmup(self, batch: FlashCausalLMBatch):
        """Profile memory with the warmup batch and size the KV cache accordingly.

        Steps:
        1. Create a cache manager sized to `batch.blocks` and run one
           `generate_token` on `batch`. A CUDA OOM here becomes a RuntimeError
           asking to lower `--max-batch-prefill-tokens`.
        2. Estimate how many cache blocks fit into the remaining free memory
           (keeping 5% headroom), add the blocks already allocated, and re-create
           the cache manager with that many blocks.
        3. On ROCm, optionally run the TunableOp GEMM warmup.
        4. If CUDA graphs are configured, capture one decode graph per size.

        Args:
            batch: the warmup batch (per the original comment, the biggest batch
                the server could ever receive).

        Returns:
            `int(num_blocks * BLOCK_SIZE)`, the KV-cache capacity (presumably in
            tokens, assuming BLOCK_SIZE is tokens per block).

        Raises:
            RuntimeError: if the warmup forward pass runs out of GPU memory.
        """
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size

            # Keep TunableOp from tuning during this profiling pass; tuning (if
            # enabled) happens afterwards in `_warmup_tunableop`.
            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Factor 2 presumably accounts for separate key and value caches.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        # Drop references to the warmup batch and the temporary cache manager
        # before the final cache manager is created.
        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            self._warmup_tunableop()

        self._warmup_cuda_graphs(max_s, max_bt)

        return int(num_blocks * BLOCK_SIZE)

    def _warmup_tunableop(self):
        """ROCm only: run the TunableOp warmup for a list of sequence lengths, unless disabled."""
        tunableop_enabled = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
        if tunableop_enabled is not None and tunableop_enabled != "1":
            logger.info(
                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
            )
            return

        if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
            torch.cuda.tunable.tuning_enable(True)

        tunableop_seqlens = os.environ.get("PYTORCH_TUNABLEOP_SEQLENS")
        if tunableop_seqlens is not None:
            tuning_sequences = [int(val) for val in tunableop_seqlens.split(",")]
        else:
            # CUDA_GRAPHS may be falsy when CUDA graphs are disabled; if it is
            # None, iterating over it below would raise TypeError.
            tuning_sequences = CUDA_GRAPHS or []

        tunableop_filepath = os.path.join(
            HUGGINGFACE_HUB_CACHE,
            f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
        )

        logger.info(
            f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
        )

        if os.path.isfile(tunableop_filepath):
            logger.info(
                f"The file {tunableop_filepath} already exists and will be reused."
            )
            torch.cuda.tunable.read_file(tunableop_filepath)

        os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

        for seqlen in tuning_sequences:
            logger.info(f"Warming up TunableOp for seqlen={seqlen}")
            self.tunableop_warmup(seqlen)
            # The results file is rewritten after every sequence length.
            torch.cuda.tunable.write_file(tunableop_filepath)
        torch.cuda.tunable.tuning_enable(False)

    def _warmup_cuda_graphs(self, max_s: int, max_bt: int):
        """Capture one decode CUDA graph per batch size listed in CUDA_GRAPHS, if any."""
        if not CUDA_GRAPHS:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
            return

        try:
            logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
            # Warmup cuda graphs
            for bs in CUDA_GRAPHS:
                # With speculation each sequence is expanded to `speculate + 1`
                # rows in `forward`, so smaller graph sizes are skipped.
                if self.speculate is None or self.speculate + 1 <= bs:
                    self.cuda_graph_warmup(bs, max_s, max_bt)
        except torch.cuda.OutOfMemoryError:
            logger.exception("Decode cuda graph warmup failed")
887

fxmarty's avatar
fxmarty committed
888
889
890
891
892
893
    def tunableop_warmup(self, seqlen: int):
        """Run one dummy prefill forward pass of `seqlen` tokens.

        `warmup` calls this once per TunableOp tuning length so that the GEMMs
        used for that length are exercised while tuning is enabled.

        Args:
            seqlen: number of dummy tokens in the single prefill sequence.
        """
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
        kv_cache = get_cache_manager().kv_cache

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=torch.tensor(
                [0, seqlen], device=self.device, dtype=torch.int32
            ),
            kv_cache=kv_cache,
            block_tables=None,
            input_lengths=input_lengths,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
        )

912
913
914
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on `batch`, replaying a captured CUDA graph when possible.

        When the batch carries speculative ids, each sequence is expanded to
        `1 + speculative_length` rows: its current token followed by its
        speculative tokens, with matching positions, slots, input lengths and
        block tables.

        Prefill batches, and batches with more rows than any captured graph, run
        eagerly through `self.model.forward`. In all other cases:
        - the inputs are copied into the static, possibly padded tensors of the
          smallest captured graph whose batch size is large enough;
        - the graph is replayed;
        - its outputs are sliced back to the real number of rows.

        Returns:
            `(logits, speculative_logits)`; `speculative_logits` may be None.
        """
        # Model inputs, shared by the plain and the speculative path.
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            # Per-row offsets 0..speculative_length: int64 (torch.arange default)
            # and an int32 copy.
            offsets = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            offsets_int = offsets.to(dtype=torch.int32)

            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + offsets
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + offsets_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + offsets_int
            ).view(-1)
            # Every expanded row of a sequence reuses that sequence's block table.
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        bs = input_ids.shape[0]
        # Smallest captured graph size able to hold `bs` rows, or None.
        graph_bs = min((size for size in self.cuda_graphs if size >= bs), default=None)
        cuda_graph = self.cuda_graphs[graph_bs] if graph_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy the inputs into the graph's static (possibly padded) tensors;
        # padded slots are set to -1 and padded input lengths to 0.
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        cuda_graph["graph"].replay()

        # Drop the padded rows from the outputs.
        logits = cuda_graph["logits"][:bs]
        graph_speculative_logits = cuda_graph["speculative_logits"]
        speculative_logits = (
            None if graph_speculative_logits is None else graph_speculative_logits[:bs]
        )
        return logits, speculative_logits
1009
1010
1011
1012

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Advance `batch` by one model step and build the resulting `Generation`s.

        The batch is mutated in place:
        - cache blocks are allocated if still needed;
        - a prefill batch is switched to decode mode;
        - ids, positions, lengths, offsets and the token chooser are advanced.
        With speculation, a request may accept several tokens in one step.

        Args:
            batch: the batch to advance.

        Returns:
            `(generations, batch, (forward_ns, decode_ns))`.
            - `batch` is None when every request has stopped.
            - `forward_ns` is the time until the GPU results were copied to the CPU.
            - `decode_ns` is the time spent building generations afterwards.
            - Only requests with `i % world_size == rank` produce a `Generation`.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        # Whether prompt logprobs must be gathered for this batch.
        prefill_logprobs = batch.prefill_next_token_indices is not None
        # Per-token prompt logprobs (Python floats), computed below if needed.
        prefill_logprobs_values: Optional[List[float]] = None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception:
            del batch
            raise

        if prefill:
            # With prompt logprobs the model returns every prompt row; keep only
            # the rows used to pick the next token.
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, n_accepted_ids) in enumerate(
            zip(batch.input_lengths, accepted_ids)
        ):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            gathered = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs_values = gathered.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded because the request stopped early.
            left = 0

            if n_accepted_ids > 1:
                if RANK == 0:
                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs_values[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # One entry per step; distinct names avoid shadowing the
                    # per-request `top_token_ids` / `top_token_logprobs`.
                    for step_top_ids, step_top_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_top_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_top_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                step_top_ids,
                                step_top_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only metadata is not needed for decode steps.
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)