flash_causal_lm.py 52.1 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
Daniël de Kok's avatar
Daniël de Kok committed
14
from typing import Iterable, Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
15
16

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Daniël de Kok's avatar
Daniël de Kok committed
17
from text_generation_server.utils.chunks import concat_text_chunks
fxmarty's avatar
fxmarty committed
18
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
19
from text_generation_server.models import Model
20
from text_generation_server.utils.tokens import batch_top_tokens
21
from text_generation_server.utils.dist import RANK
Nicolas Patry's avatar
Nicolas Patry committed
22
from text_generation_server.utils.speculate import get_speculate
23
24
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
25
    Tokens,
26
27
28
29
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
30
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
fxmarty's avatar
fxmarty committed
31
import text_generation_server.models.globals as tgi_globals
32
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
33
from text_generation_server.utils.dist import MEMORY_FRACTION
34

Nicolas Patry's avatar
Nicolas Patry committed
35
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
36
37
38
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
39
40
)

Nicolas Patry's avatar
Nicolas Patry committed
41
42
# Module-level OpenTelemetry tracer; its `start_as_current_span` decorator wraps
# `FlashCausalLMBatch.filter` and `FlashCausalLMBatch.concatenate` below.
tracer = trace.get_tracer(__name__)

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Number of token slots in one paged-attention KV-cache block
# (block b owns slots [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)).
BLOCK_SIZE: int = 16

# Set through `set_sliding_window` during init; None means no sliding window.
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: Optional[int]):
    """Set the module-wide sliding-window size.

    Args:
        sliding_window: Window size in tokens, or ``None`` to disable the
            sliding window (batches then keep the full prefill KV tensor).
    """
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    """Return the sliding-window size, or ``None`` if none has been set."""
    # Reading a module global needs no ``global`` declaration.
    return SLIDING_WINDOW

58

59
60
61
62
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state for flash / paged-attention causal language models.

    Holds the flattened token and position tensors fed to the model, the
    paged-attention bookkeeping (slots and block tables), and the per-request
    generation state (stopping criteria, detokenization offsets, sampling
    parameters). Batches are built with ``from_pb`` / ``from_tokenized`` and
    reshaped with ``filter`` and ``concatenate``.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize the batch as a ``generate_pb2.CachedBatch`` message.

        ``max_tokens`` is the number of KV slots covered by this batch's
        blocks (``num_blocks * BLOCK_SIZE``).
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        """Tokenize the text of every request with a single tokenizer call.

        Every input is truncated to the largest ``truncate`` value among the
        requests; ``from_tokenized`` then applies each request's own limit.
        ``requests`` is iterated exactly once, so a one-shot iterator works.

        Returns:
            The tokenizer's ``input_ids`` output, one entry per request.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from requests that are already tokenized.

        Each request is truncated to its own ``truncate`` value and a
        duplicated leading BOS token is dropped. When a request carries no
        ``blocks`` (for example during warmup), block ids and slots are
        assigned locally, continuing after the blocks of earlier requests.

        Args:
            pb: Batch message holding the requests.
            tokenizer: Used for the BOS check, the stopping criteria and the
                next-token chooser.
            batch_tokenized_inputs: One token-id sequence per request, in the
                same order as ``pb.requests``.
            dtype: Forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: Device on which the batch tensors are created.
        """
        sliding_window = get_sliding_windows()
        # The speculation depth does not change while building a batch, so
        # read it once instead of once per request.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length

        position_ids = []
        cu_seqlen_prefill = [0]
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        block_tables = []
        slots = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length guard keeps an
            # input made of a single BOS token from raising IndexError.
            if (
                tokenized_input[0] == tokenizer.bos_token_id
                and len(tokenized_input) > 1
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            slots.extend(request_slots[:total_tokens])
            num_blocks += len(request_blocks)
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, request_input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(request_input_ids)] = request_input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # Convert the concatenated tensor with `.to` instead of
            # `torch.tensor(tensor)`, which PyTorch discourages (it warns and
            # recommends clone().detach()).
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        # Fill the block table on CPU, then move it to the device in one copy.
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize ``pb.requests`` and build a prefill batch from them."""
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch that keeps only the requests in ``request_ids``.

        Per-request tensors are re-indexed. The KV slots still needed by the
        kept requests are gathered into a contiguous ``slots`` tensor.
        Prefill-only metadata is dropped (set to ``None``). If
        ``len(request_ids) == len(self)``, ``self`` is returned unchanged on
        the assumption that the ids are the same.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        num_blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Keep the slots for the tokens seen so far plus those still to be
            # generated; the "- 1" mirrors the one in from_tokenized's total_tokens.
            kept_slots = request_input_length + remaining_tokens - 1
            slot_start = self.start_slots[idx]
            slot_filtering_indices[slot_start : slot_start + kept_slots] = True

            cumulative_max_length += kept_slots

            max_blocks = max(max_blocks, len(request_block_table))

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into a single decode batch.

        Requests keep their order across ``batches``. Slot indices and start
        slots of each batch are shifted by the slot count of the batches
        before it. The result keeps the first batch's ``batch_id`` and has no
        prefill metadata. The input batches are not modified; in particular
        their ``requests_idx_mapping`` dicts are left untouched.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the number of requests before it.
            # Writing into a fresh dict (instead of reusing the first batch's
            # mapping) avoids mutating batches[0] as a side effect.
            for request_id, idx in batch.requests_idx_mapping.items():
                requests_idx_mapping[request_id] = idx + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Set up a flash-attention causal LM wrapper.

        The KV-cache geometry (layer count, KV heads, head size) is recorded
        here, but no cache is allocated: ``self.kv_cache`` starts empty and is
        filled later by ``init_kv_cache``. ``self.cuda_graphs`` starts empty and
        maps a batch size to its captured decode graph state once
        ``cuda_graph_warmup`` has run. Everything else is forwarded to the base
        ``Model`` constructor with ``requires_padding=False``.
        """
        # KV-cache geometry, read by init_kv_cache() and warmup().
        self.num_layers, self.num_kv_heads, self.head_size = (
            num_layers,
            num_kv_heads,
            head_size,
        )

        # Both are populated during warmup.
        self.cuda_graphs = {}
        self.kv_cache = []

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class this model consumes: ``FlashCausalLMBatch``."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
    def max_past(self) -> Optional[int]:
        """Return the wrapped model's ``max_past`` attribute, or ``None``.

        ``forward`` uses a non-None value to clamp ``max_s`` during decode
        steps, where the KV cache is written as a circular buffer. Models that
        do not define the attribute yield ``None``, hence ``Optional[int]``.
        """
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """(Re)allocate the paged KV cache as one (key, value) pair per layer.

        The previous cache is dropped and ``empty_cache()`` is called before
        allocating, so the old memory can be reused. Tensors come from
        ``torch.empty`` and are uninitialized.

        Per-layer shapes:
            key:   (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
            value: (num_blocks, num_heads, head_size, BLOCK_SIZE)
        where ``x`` is 1 on XPU and ``BLOCK_SIZE // element_size`` elsewhere.
        """
        # Release the old cache before allocating the new one.
        self.kv_cache = []
        empty_cache()

        if SYSTEM == "xpu":
            x = 1
        else:
            # NOTE(review): BLOCK_SIZE is the token dimension in the shapes
            # below but acts as a byte width here; confirm this is the intended
            # key layout for the attention kernel.
            x = BLOCK_SIZE // torch.tensor([], dtype=dtype).element_size()

        key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
        value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)

        layers = []
        for _layer in range(num_layers):
            key_cache = torch.empty(key_shape, dtype=dtype, device=device)
            value_cache = torch.empty(value_shape, dtype=dtype, device=device)
            layers.append((key_cache, value_cache))
        self.kv_cache = layers

792
793
794
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for batch size ``bs``.

        Static input buffers of ``bs`` rows are allocated. Block tables are
        ``max_bt`` columns wide, and every row claims input length ``max_s``.
        The model is run once eagerly as a warmup, then one decode forward
        (``cu_seqlen_prefill=None``) is captured into a ``torch.cuda.CUDAGraph``
        using the shared ``MEM_POOL``. ``forward`` later copies real inputs
        into these buffers and replays the graph.

        The entry is published to ``self.cuda_graphs[bs]`` only after the
        capture succeeds. If the warmup or capture raises (for example
        ``torch.cuda.OutOfMemoryError``, which ``warmup`` catches and logs), no
        half-built entry without ``"logits"`` is left for ``forward`` to pick.

        Raises:
            Whatever ``self.model.forward`` or graph capture raises.
        """
        # Never leave a stale graph for this size around if capture fails.
        self.cuda_graphs.pop(bs, None)

        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )

        graph = torch.cuda.CUDAGraph()
        # Built locally; only stored in self.cuda_graphs once fully captured.
        entry = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
            "graph": graph,
        }

        # Same static buffers for the eager warmup and the captured run.
        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(**forward_kwargs)
            entry["logits"] = logits
            entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

        # Capture succeeded: make the graph visible to forward().
        self.cuda_graphs[bs] = entry

847
    def warmup(self, batch: FlashCausalLMBatch):
        """Size the KV cache from a worst-case batch and warm up kernels.

        Steps:
          1. Allocate a KV cache just large enough for ``batch`` and run one
             ``generate_token`` step on it.
          2. Convert the remaining free device memory (minus 5% headroom) into
             a number of KV-cache blocks and reallocate the cache at that size.
          3. On ROCm, optionally tune GEMMs with PyTorch TunableOp.
          4. If ``CUDA_GRAPHS`` is set, capture decode CUDA graphs for those
             batch sizes. Out-of-memory failures there are logged, not raised.

        Args:
            batch: the warmup batch, i.e. the biggest batch the server could
                ever receive.

        Returns:
            The number of token slots in the final KV cache
            (``num_blocks * BLOCK_SIZE``).

        Raises:
            RuntimeError: if the warmup step runs out of device memory.
        """
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        # Read these before the step. `generate_token` rebinds fields such as
        # `batch.input_ids` and returns None once every request has stopped.
        # The cache allocated for these blocks stays resident either way.
        warmup_num_blocks = batch.num_blocks
        warmup_num_tokens = len(batch.input_ids)

        try:
            self.init_kv_cache(
                warmup_num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {warmup_num_tokens} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Keys and values (x2) for every layer.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add the warmup blocks allocated above: they are still held, so
            # they are not part of `free_memory`.
            + warmup_num_blocks
        )

        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            self._warmup_rocm_tunableop()

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    # With speculation each sequence needs `speculate + 1` rows.
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        return int(num_blocks * BLOCK_SIZE)

    def _warmup_rocm_tunableop(self):
        """Enable PyTorch TunableOp on ROCm and tune GEMMs for target seqlens.

        TunableOp is enabled when ``PYTORCH_TUNABLEOP_ENABLED`` is unset or
        ``"1"``. Tuning is switched on unless ``PYTORCH_TUNABLEOP_TUNING`` is
        ``"0"``. Sequence lengths come from ``PYTORCH_TUNABLEOP_SEQLENS``
        (comma separated), else ``CUDA_GRAPHS``, else 2..7. Results are read
        from and written to a per-model, per-rank CSV in
        ``HUGGINGFACE_HUB_CACHE``. Tuning is switched off again at the end.
        """
        tunableop_env = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
        if tunableop_env is not None and tunableop_env != "1":
            logger.info(
                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
            )
            return

        torch.cuda.tunable.enable()

        if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
            torch.cuda.tunable.tuning_enable(True)

        if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
            tuning_sequences = [
                int(val)
                for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
            ]
        elif CUDA_GRAPHS is not None:
            tuning_sequences = CUDA_GRAPHS
        else:
            # For seqlen = 1, we dispatch to LLMM1 kernel.
            tuning_sequences = [2, 3, 4, 5, 6, 7]

        tunableop_filepath = os.path.join(
            HUGGINGFACE_HUB_CACHE,
            f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
        )

        logger.info(
            f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
        )

        if os.path.isfile(tunableop_filepath):
            logger.info(
                f"The file {tunableop_filepath} already exists and will be reused."
            )
            torch.cuda.tunable.read_file(tunableop_filepath)

        os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

        for seqlen in tuning_sequences:
            logger.info(f"Warming up TunableOp for seqlen={seqlen}")
            self.tunableop_warmup(seqlen)
            # Persist after each length so partial progress is kept.
            torch.cuda.tunable.write_file(tunableop_filepath)
        torch.cuda.tunable.tuning_enable(False)
961

fxmarty's avatar
fxmarty committed
962
963
964
965
966
    def tunableop_warmup(self, seqlen: int):
        """Run one prefill-mode forward over ``seqlen`` dummy (zero) tokens.

        ``warmup`` calls this per target length while TunableOp tuning is on.
        The dummy inputs form a single sequence spanning ``[0, seqlen)``.
        """
        device = self.device

        # Dummy value, some models (starcoder2) don't accept `None`.
        dummy_lengths = torch.ones(seqlen, dtype=torch.int32, device=device)
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        single_sequence = torch.tensor([0, seqlen], device=device, dtype=torch.int32)

        self.model.forward(
            input_ids=torch.zeros(seqlen, dtype=torch.int64, device=device),
            position_ids=torch.zeros(seqlen, dtype=torch.int32, device=device),
            cu_seqlen_prefill=single_sequence,
            kv_cache=self.kv_cache,
            block_tables=None,
            input_lengths=dummy_lengths,
            slots=torch.arange(seqlen, dtype=torch.int64, device=device),
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

986
987
988
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch`` and return ``(logits, speculative_logits)``.

        Prefill steps, and decode steps with no captured CUDA graph big enough,
        call ``self.model.forward`` directly. Other decode steps go through the
        smallest captured graph whose batch size is at least the number of
        input rows:
          - inputs are copied into its static buffers (padding rows get slot
            -1 and input length 0);
          - the graph is replayed;
          - outputs are sliced back to the real row count.

        With speculative ids, every sequence is expanded to
        ``1 + speculative_length`` rows. Each row holds the current token or one
        speculated token, with consecutive positions, slots and input lengths,
        and its own copy of the sequence's block table.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            # Per-row offsets 0..speculative_length (int64 and int32 variants).
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)

            # Each sequence's current token followed by its speculated tokens.
            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        if cu_seqlen_prefill is None:
            max_past = self.max_past()
            if max_past is not None:
                # In decode, not prefill, we're actually overwriting the KV-cache
                # in a circular buffer mode.
                # This makes sure the max_s for the decode pass is correct.
                max_s = min(max_past, max_s)

        bs = input_ids.shape[0]
        # Smallest captured batch size able to hold `bs` rows, if any.
        padded_bs = min((k for k in self.cuda_graphs.keys() if k >= bs), default=None)
        cuda_graph = None if padded_bs is None else self.cuda_graphs[padded_bs]

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
            )
            # Cleared after use so later steps pass None.
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][:bs] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        n_table_rows = block_tables.shape[0]
        n_table_cols = block_tables.shape[1]
        cuda_graph["block_tables"][:n_table_rows, :n_table_cols] = block_tables

        static_slots = cuda_graph["slots"]
        static_slots.fill_(-1)
        static_slots[: slots.shape[0]] = slots

        static_lengths = cuda_graph["input_lengths"]
        static_lengths.zero_()
        static_lengths[: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        logits = cuda_graph["logits"][:bs]
        speculative_logits = cuda_graph["speculative_logits"]
        if speculative_logits is not None:
            speculative_logits = speculative_logits[:bs]
        return logits, speculative_logits
1093
1094
1095
1096

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one model step on ``batch`` and build the resulting generations.

        The work is split into two passes over the batch. The first pass only
        issues tensor updates; the intent, per the in-code comment, is to delay
        the GPU -> CPU sync. The second pass runs after the ``.tolist()``
        copies. It detokenizes, applies stopping criteria and builds one
        ``Generation`` per request owned by this shard.

        ``batch`` is updated in place for the next step: input ids, positions,
        slot indices, input lengths, offsets, token history, speculative ids
        and grammar state. ``cu_seqlen_prefill`` is cleared after a prefill
        step, and the prefill output-index fields are cleared when the batch is
        returned.

        Returns:
            ``(generations, batch, (forward_ns, decode_ns))``.
            - ``batch`` is ``None`` when every request stopped.
            - ``generations`` only holds requests whose index ``i`` satisfies
              ``i % world_size == rank``.
            - ``forward_ns`` spans up to the host sync; ``decode_ns`` spans the
              second pass.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        out, speculative_logits = self.forward(batch)

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            # Keep the slot index of each request's last prompt token
            # (cu_seqlen_prefill holds cumulative prompt lengths).
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            # Write the accepted tokens into the token-history tensor.
            # NOTE(review): range() on a tensor element converts it to int,
            # which may force a device sync if accepted_ids is on GPU; confirm.
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        # The last accepted token of each request becomes its next input.
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded after a stop condition.
            left = 0

            if n_accepted_ids > 1:
                if RANK == 0:
                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            reason = None
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # One entry per accepted step of this request.
                    for step_top_ids, step_top_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_top_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_top_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                step_top_ids,
                                step_top_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)