flash_causal_lm.py 49.5 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
from transformers import PreTrainedTokenizerBase
14
from typing import Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
15
16
17

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
18
from text_generation_server.models import Model
19
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
20
from text_generation_server.utils.speculate import get_speculate
21
22
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
23
    Tokens,
24
25
26
    Generation,
    GeneratedText,
)
27
28
29
30
31
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
    BLOCK_SIZE,
)
32
from text_generation_server.pb import generate_pb2
33
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
fxmarty's avatar
fxmarty committed
34
import text_generation_server.models.globals as tgi_globals
35
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
36
from text_generation_server.utils.dist import MEMORY_FRACTION
37

Nicolas Patry's avatar
Nicolas Patry committed
38
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
39
40
41
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
42
43
)

Nicolas Patry's avatar
Nicolas Patry committed
44
45
# OpenTelemetry tracer for this module; FlashCausalLMBatch.filter and
# FlashCausalLMBatch.concatenate are wrapped in spans created from it.
tracer = trace.get_tracer(__name__)

46

47
48
49
50
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state for flash/paged-attention causal LM generation.

    A batch is created for prefill by ``from_pb`` / ``from_tokenized`` and is then
    reshaped between steps with ``filter`` and ``concatenate``. It owns the
    KV-cache blocks listed in ``block_tables``. Those blocks are returned to the
    cache manager when requests are filtered out, or when the batch is garbage
    collected (``__del__``). Operations that hand the blocks over to another batch
    set ``block_tables`` to ``None`` on the old batch so the blocks are not freed
    twice.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # None for batches built by from_tokenized
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b + 1 containing the cumulative sequence lengths of the
    # sequences in the batch (starting at 0), only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Serialize the batch to a ``CachedBatch`` protobuf.

        ``max_tokens`` reports the KV-cache capacity reserved by the batch: the
        number of blocks times ``BLOCK_SIZE``.
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer):
        """Tokenize the inputs of ``requests`` with a single tokenizer call.

        Every input is truncated to the largest ``truncate`` value in the batch.
        The per-request ``truncate`` is applied afterwards, in ``from_tokenized``.

        Returns:
            The tokenizer's ``input_ids`` output, indexed per request in the
            order of ``requests``.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from a protobuf ``Batch``.

        Tokenizes every request, then delegates to ``from_tokenized``.
        """
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from already-tokenized request inputs.

        For each request this method:

        - keeps the last ``r.truncate`` tokens;
        - drops one BOS token when the input starts with two of them;
        - reserves ``input_length + max_new_tokens - 1 + speculative_length``
          KV-cache slots, rounded up to whole blocks of ``BLOCK_SIZE``.

        ``block_tables`` and ``slots`` are left as ``None``; per the field
        comments, they are set later by the cache manager.

        Args:
            pb: Protobuf batch holding the requests.
            tokenizer: Tokenizer used for BOS detection, the stopping criteria and
                the next-token chooser.
            batch_tokenized_inputs: Token ids for each request of ``pb.requests``,
                in the same order.
            dtype: Dtype passed to the next-token chooser.
            device: Device on which the batch tensors are created.
        """
        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # get_speculate() does not depend on the request: read it once.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop one BOS token when the input starts with two of them. The
            # length check avoids an IndexError on single-token inputs
            # (e.g. an input that is just BOS).
            if (
                len(tokenized_input) >= 2
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Select every prompt position of this request
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Select only the last prompt position of this request
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # torch.cat already yields a tensor: convert/move it instead of
            # copy-constructing a new one with torch.tensor(), which PyTorch
            # warns against.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                device=device, dtype=torch.int64
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch that contains only ``request_ids``.

        KV-cache blocks of the dropped requests are freed immediately. The blocks
        of the kept requests move to the new batch: this batch's ``block_tables``
        is cleared so that its ``__del__`` does not free them. Each kept request
        keeps ``input_length + remaining_tokens - 1`` slots, packed contiguously
        in the new ``slots`` tensor.

        Args:
            request_ids: Ids of the requests to keep. Each id must be present in
                ``requests_idx_mapping``; otherwise ``KeyError`` is raised.

        Returns:
            ``self`` when ``len(request_ids) == len(self)`` (the ids are assumed to
            be the same); otherwise a new batch.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping:
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        get_cache_manager().free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge ``batches`` into a single batch.

        Per-request tensors are copied into freshly allocated tensors:

        - ``slot_indices`` and ``start_slots`` are offset by the number of slots
          of the preceding batches;
        - ``requests_idx_mapping`` is offset by the number of preceding requests;
        - the next-token chooser is rebuilt from the request parameters, keeping
          each request's FSM grammar state.

        Block ownership moves to the returned batch: every input's
        ``block_tables`` is set to ``None`` so its ``__del__`` does not free the
        blocks. The result reuses the ``batch_id`` of the first batch.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the cumulative batch size (0 for the
            # first batch). Entries go into a new dict so that no input batch's
            # mapping is mutated.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
        """Return the KV-cache blocks still owned by this batch to the cache manager."""
        # getattr: __del__ can run on an instance whose __init__ failed before
        # block_tables was assigned.
        block_tables = getattr(self, "block_tables", None)
        if block_tables:
            # Free blocks
            get_cache_manager().free(list(itertools.chain.from_iterable(block_tables)))

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        """Wrap a flash-attention causal LM.

        Args:
            model: The underlying torch module; its ``forward`` is called by
                ``forward``/``cuda_graph_warmup``/``tunableop_warmup``.
            tokenizer: Tokenizer forwarded to ``Model``.
            num_layers: Number of layers; used when sizing the KV cache.
            num_kv_heads: Number of key/value heads; used when sizing the KV cache.
            head_size: Per-head dimension; used when sizing the KV cache.
            dtype: Model dtype, also used for the KV cache.
            device: Device the model and cache live on.
            rank: Shard rank of this process.
            world_size: Total number of shards.
            sliding_window: Optional sliding-window size forwarded to ``Model``.
        """
        # KV-cache geometry, read back by ``warmup`` when (re)creating the cache.
        self.num_layers, self.num_kv_heads, self.head_size = (
            num_layers,
            num_kv_heads,
            head_size,
        )
        # Captured decode CUDA graphs, keyed by padded batch size.
        self.cuda_graphs = {}

        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class this model consumes (``FlashCausalLMBatch``)."""
        batch_cls = FlashCausalLMBatch
        return batch_cls
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for a padded batch of ``bs`` tokens.

        Allocates static input buffers and runs one eager forward pass as a
        warmup. It then records a second pass into a ``torch.cuda.CUDAGraph``
        using the shared ``MEM_POOL``. The buffers, the graph and its output
        tensors are stored in ``self.cuda_graphs[bs]``, so ``forward`` can copy
        new inputs into them and replay the graph.

        The entry is registered only after capture succeeds. ``forward``
        selects any registered size ``>= bs``, so a half-built entry left
        behind by a failed capture (e.g. an OOM that ``warmup`` catches and
        logs) would later be replayed without a valid graph or outputs.

        Args:
            bs: Number of decode tokens the graph is captured for.
            max_s: Maximum sequence length passed to the model; also the value
                the static ``input_lengths`` buffer is filled with.
            max_bt: Number of block-table columns per row.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        # Built locally; published to self.cuda_graphs only after capture.
        entry = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        entry["graph"] = graph

        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(**forward_kwargs)
            entry["logits"] = logits
            entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

        # Capture succeeded: make the graph visible to forward().
        self.cuda_graphs[bs] = entry
    def warmup(self, batch: FlashCausalLMBatch):
        """Size the KV cache from a worst-case batch and pre-build fast paths.

        Steps:
        1. Create a cache just large enough for ``batch``.
        2. Run one ``generate_token`` on it to reach peak memory.
        3. Recreate the cache with as many blocks as 95% of the remaining free
           memory allows, plus the blocks already used.
        4. On ROCm, optionally run TunableOp GEMM tuning.
        5. Capture CUDA graphs for the configured sizes.

        Args:
            batch: The largest batch the server may receive.

        Returns:
            Number of tokens the final KV cache can hold
            (``num_blocks * BLOCK_SIZE``).

        Raises:
            RuntimeError: If the warmup batch runs out of GPU memory.
        """
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        # TunableOp counts as enabled when the variable is unset or "1". The
        # same rule is used for the prefill guard below and the tuning pass,
        # so both agree on whether TunableOp is active.
        tunableop_env = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
        tunableop_enabled = tunableop_env is None or tunableop_env == "1"

        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size

            if SYSTEM == "rocm" and tunableop_enabled:
                # Presumably avoids tuning GEMMs for the one-off warmup prefill
                # shape; tuning is re-enabled below before the tuning pass.
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        # Drop references to the warmup cache before allocating the real one.
        del batch
        del cache_manager

        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            if tunableop_enabled:
                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    # CUDA graphs are disabled (CUDA_GRAPHS is falsy, e.g. None),
                    # so there is no size list to reuse; iterating/joining None
                    # would crash. Fall back to a small fixed set of sizes.
                    tuning_sequences = [2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                logger.info(
                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
                )

                if os.path.isfile(tunableop_filepath):
                    logger.info(
                        f"The file {tunableop_filepath} already exists and will be reused."
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                    # Persist after every length so partial progress is kept.
                    torch.cuda.tunable.write_file(tunableop_filepath)
                torch.cuda.tunable.tuning_enable(False)
            else:
                logger.info(
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
                )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    # With speculation each sequence needs speculate + 1 tokens.
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        return int(num_blocks * BLOCK_SIZE)
    def tunableop_warmup(self, seqlen: int):
        """Run one prefill-style forward over ``seqlen`` dummy tokens.

        ``warmup`` calls this for each tuning length while TunableOp tuning is
        enabled, so the model's matmuls are executed at that sequence length.

        Args:
            seqlen: Number of tokens in the single dummy sequence.
        """
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
        # Looked up once and reused (previously fetched twice, local unused).
        kv_cache = get_cache_manager().kv_cache

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=torch.tensor(
                [0, seqlen], device=self.device, dtype=torch.int32
            ),
            kv_cache=kv_cache,
            block_tables=None,
            input_lengths=None,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
        )
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch`` and return ``(logits, speculative_logits)``.

        With speculative ids, each sequence's current token is followed by its
        speculative tokens. Position ids, slots, input lengths and block tables
        are expanded to ``B * (speculative_length + 1)`` rows.

        The model runs eagerly during prefill, or when no captured CUDA graph
        of size ``>=`` the token count exists. Otherwise the inputs are copied
        into the smallest such graph's static buffers (unused slots set to -1,
        unused input lengths to 0), the graph is replayed, and its outputs are
        sliced back to the real token count.
        """
        # Model Forward: gather inputs (shared by both paths below).
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        bs = input_ids.shape[0]
        # Smallest captured graph that can hold `bs` tokens, if any.
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            return self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                lm_head_indices=lm_head_indices,
            )

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()
        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one prefill or decode step for ``batch``.

        Allocates cache blocks if the batch still needs them, runs ``forward``,
        picks next tokens (including speculative acceptance) and updates the
        batch state in place. It then builds a ``Generation`` for every request
        owned by this shard (``i % world_size == rank``).

        Returns:
            ``(generations, batch, (forward_ns, decode_ns))``. ``batch`` is
            ``None`` when every request in it stopped.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception:
            del batch
            # Bare raise re-raises the active exception with its traceback.
            raise

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded after a stop condition.
            left = 0

            # Tracked per request; the inner loop's `stop`/`reason` would
            # otherwise leak between iterations and requests.
            current_stopped = False
            reason = None
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # Distinct loop names: do not rebind the iterables being zipped.
                    for step_token_ids, step_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_token_ids
                        ]
                        top_tokens = Tokens(
                            step_token_ids,
                            step_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only metadata is no longer needed once decoding continues.
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)