# flash_causal_lm.py
from contextlib import nullcontext
import math
import os
import time
import torch
import torch.distributed

import numpy as np

from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
from typing import (
    Any,
    ContextManager,
    Iterable,
    Optional,
    Tuple,
    List,
    Type,
    Dict,
    Union,
)

from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
    get_support_chunking,
    get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
    MEM_POOL,
    ATTENTION,
    BLOCK_SIZE,
    CUDA_GRAPHS,
    TGI_WIGGLE_ROOM,
    get_adapter_to_index,
)
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)
from text_generation_server.models.metadata_kernels import (
    has_triton,
    copy_next_input_ids_inplace,
    block_tables_to_ragged,
    block_tables_to_padded,
    prepare_position_slot_ids,
    slots_filtering,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None


def small_power_of_2(n: int):
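    # Largest power of two strictly smaller than `n` (e.g. 9 -> 8, 8 -> 4).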
    return 1 << ((n - 1).bit_length() - 1)


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    global SLIDING_WINDOW
    return SLIDING_WINDOW
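

# Usage sketch (illustrative, based on this file only): `FlashCausalLM.__init__`
# calls `set_sliding_window(config.sliding_window)` when the model config defines
# one, and `prepare_for_prefill` reads it back via `get_sliding_windows()` to
# build `prefill_cache_indices`, so that only the last SLIDING_WINDOW keys/values
# of a long prefill are written to the paged KV cache.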


def init_cpu_threads_env(rank_id: int, world_size: int):
    import importlib.util

    if importlib.util.find_spec("numa") is not None:
        import numa
        import psutil

        nodes = numa.info.get_max_node() + 1
        rank_per_node = math.ceil(world_size / nodes)
        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
        node_id = int(rank_id / rank_per_node)
        rank_offset_per_node = rank_id % rank_per_node
        if os.getenv("OMP_NUM_THREADS") is None:
            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
        else:
            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
        if len(numa.memory.get_membind_nodes()) == nodes:
            numa.memory.set_membind_nodes((node_id))
        torch.set_num_threads(num_cpus_per_rank)
        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
            cpu_start = num_cpus_per_rank * rank_offset_per_node
            numa.schedule.run_on_cpus(
                0,
                *(
                    numa.info.node_to_cpus(node_id)[
                        cpu_start : cpu_start + num_cpus_per_rank
                    ]
                ),
            )
        logger.info(
            f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
        )


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    # Can be a list for easy filtering
    # If `input_ids` is a list, it needs to be materialized to a tensor first
    input_ids: Union[torch.Tensor, List[List[int]]]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    position_ids: Optional[torch.Tensor]
    speculative_ids: Optional[torch.Tensor]

    # Set when creating the batch
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    slot_indices: Optional[torch.Tensor]

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor
    # list of length b + 1  containing the cumulative sequence slot lengths of the sequences in the batch
    # used for filtering
    cu_slots: torch.Tensor

    max_input_length: int
    max_current_length: int

    # Whether this batch contains at least one request that is prefilling
    prefilling: bool
    # Whether each request is prefilling
    prefilling_mask: List[bool]

    # Prefill metadata tensors to efficiently compute logprobs
    # tensor of length b + 1  containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_head_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_next_token_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_cu_outlens: Optional[List[int]]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_logprob_tokens: List[Optional[Tokens]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # size [b], containing the number of blocks that can be retrieved from the cache
    cache_lengths: List[int]
    prompt_lengths: List[int]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    input_lengths_tensor: Optional[torch.Tensor]
    cache_lengths_tensor: Optional[torch.Tensor]
    prompt_lengths_tensor: torch.Tensor

    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    adapter_meta: Optional[AdapterBatchMetadata]

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
            current_tokens=(
                sum([len(i) for i in self.input_ids])
                if isinstance(self.input_ids, list)
                else len(self.input_ids)
            ),
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        max_length = 0
        all_input_ids = []
        batch_size = 0
        for r in requests:
            batch_size += 1
            inputs = concat_text_chunks(r.input_chunks.chunks)
            input_ids = tokenizer(
                inputs,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            max_length = max(max_length, len(input_ids))
            all_input_ids.append(input_ids)
        return all_input_ids

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        speculate = get_speculate()

        cache_lengths = []
        input_lengths = []
        prompt_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        all_postfix_ids = []
        requests_idx_mapping = {}
        slots = []
        cu_slots = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        num_blocks = 0
        max_input_length = 0
        max_current_length = 0
        max_length = 0
        max_blocks = 0

        cu_blocks = [0]
        block_tables = []
        block_tables_ragged = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            prompt_length = len(tokenized_input)
            prompt_lengths.append(prompt_length)

            cache_length = r.cache_len

            assert (
                cache_length <= prompt_length
            ), f"Prefix {cache_length} vs input {prompt_length}"
            if cache_length == prompt_length:
                assert False, "unreachable"

            # `chunk_len` is an optional field in the protobuf
            # It is only set if the model support chunking
            if r.HasField("chunk_len"):
                input_length = r.chunk_len

                if cache_length + input_length < prompt_length:
                    # FIXME: speculate is not supported for context chunking at the moment
                    assert speculate == 0
                    assert get_support_chunking()
                    assert input_length > 0

                postfix_ids = tokenized_input[
                    cache_length : cache_length + input_length
                ]
                assert (
                    len(postfix_ids) == input_length
                ), "Rust and Python tokenizers are not aligned"
            else:
                # Use all the remaining ids
                postfix_ids = tokenized_input[cache_length:]
                input_length = len(postfix_ids)

            input_lengths.append(input_length)

            prefix_offsets.append(prompt_length - 5)
            read_offsets.append(prompt_length)
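            # The 5-token lookback (prompt_length - 5) gives the incremental
            # detokenizer some left context so streamed text stays stable across
            # token boundaries.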

            all_postfix_ids.append(postfix_ids)
            all_input_ids.append(tokenized_input)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length

            # Tokens that need to be mapped to blocks.
            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            block_tables_ragged.extend(request_blocks)
            cu_blocks.append(len(block_tables_ragged))

            slots.extend(request_slots)
            cu_slots.append(len(slots))
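            # Worked example for the fallback above where the router did not
            # pre-allocate blocks (assuming BLOCK_SIZE == 16): prompt_length == 10,
            # max_new_tokens == 20 and speculative_length == 0 give
            # block_tokens == 10 + 20 - 1 + 0 == 29, i.e. ceil(29 / 16) == 2 blocks
            # and 2 * 16 == 32 slots for this request.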

            cache_lengths.append(cache_length)
            num_blocks += len(request_blocks)

            # Update
            max_blocks = max(max_blocks, len(request_blocks))
            max_input_length = max(max_input_length, input_length)
            max_current_length = max(max_current_length, cache_length + input_length)
            max_length = max(
                max_length,
                prompt_length + max_new_tokens + speculative_length,
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids
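        # e.g. prompts of lengths [3, 5] with max_length == 8 produce a (2, 8)
        # matrix: row i holds the prompt ids left-aligned and the zero tail is
        # filled in as new tokens are generated.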

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        block_tables_ragged = torch.tensor(
            block_tables_ragged, device=device, dtype=torch.int32
        )
        cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
        block_tables_tensor = torch.empty(
            (len(block_tables), max_blocks),
            device=device,
            dtype=torch.int32,
        )

        # If the device supports Triton, we can use a fused kernel
        if has_triton():
            block_tables_to_padded(
                max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged
            )
        else:
            for i, request_blocks in enumerate(block_tables):
                block_tables_tensor[i, : len(request_blocks)] = torch.tensor(
                    request_blocks
                )

        prompt_lengths_tensor = torch.tensor(
            prompt_lengths, dtype=torch.int32, device=device
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
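        # e.g. requests with 32 and 16 slots give cu_slots == [0, 32, 48]; it is
        # built on the CPU (no device=...) and only moved to the GPU where needed
        # (see `filter` and `prepare_for_prefill`).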

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=True,
            prefilling_mask=[True] * len(pb.requests),
            prefill_logprob_tokens=[None] * len(pb.requests),
            input_lengths=input_lengths,
            prompt_lengths=prompt_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=None,
            slots=slots,
            cu_slots=cu_slots,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            cache_lengths_tensor=None,
            input_lengths_tensor=None,
            adapter_meta=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        assert len(pb.requests) > 0
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.block_tables_tensor.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        if not has_triton():
            # slots to keep after filtering
            slot_filtering_indices = torch.zeros(
                self.slots.shape[0], dtype=torch.bool, device=device
            )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_input_length = 0
        max_current_length = 0

        requests = []
        block_tables = []
        all_input_ids = []
        input_ids = []

        prompt_lengths = []
        input_lengths = []
        cache_lengths = []
        prefix_offsets = []
        read_offsets = []
        cu_slots = [0]

        prefilling_mask = []
        prefill_logprob_tokens = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()

        num_blocks = 0
        max_blocks = 0
        max_slots = 0
        cumulative_slot_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Prefilling
            request_prefilling = self.prefilling_mask[idx]
            prefilling_mask.append(request_prefilling)

            # Get length
            request_input_length = self.input_lengths[idx]
            request_cache_length = self.cache_lengths[idx]
            max_input_length = max(max_input_length, request_input_length)
            max_current_length = max(
                max_current_length, request_cache_length + request_input_length
            )

            all_input_ids.append(self.all_input_ids[idx])

            prompt_lengths.append(self.prompt_lengths[idx])
            input_lengths.append(request_input_length)
            cache_lengths.append(request_cache_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])
            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)

            start_slot = self.cu_slots[idx]
            end_slot = self.cu_slots[idx + 1]
            slot_length = end_slot - start_slot

            if not has_triton():
                # Set slice
                slot_filtering_indices[start_slot:end_slot] = True

            cu_slots.append(cumulative_slot_tokens + slot_length)

            # Input ids if the request was part of a prefilling batch
            # If the batch was decoding we can index into the tensor directly later
            if self.prefilling:
                input_ids.append(self.input_ids[idx])
            else:
                # Copy to tensor (CPU)
                slot_indices[i] = cumulative_slot_tokens + request_cache_length

            cumulative_slot_tokens += slot_length
            max_blocks = max(max_blocks, len(request_block_table))
            max_slots = max(max_slots, slot_length)

        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )
        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

        if not has_triton():
            slots = self.slots[slot_filtering_indices]
        else:
            slots = self.slots.new_empty(cumulative_slot_tokens)
            gpu_cu_slots = cu_slots.to(device)
            slots_indexing_start = self.cu_slots.to(device)[indices]
            slots_filtering(
                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
            )

        if self.prefilling:
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
        else:
            # Index into tensors
            input_ids = self.input_ids[indices]
            position_ids = self.position_ids[indices]
            adapter_indices = self.adapter_meta.adapter_indices[indices]
            input_lengths_tensor = self.input_lengths_tensor[indices]
            cache_lengths_tensor = self.cache_lengths_tensor[indices]

            # Move to GPU now that we have the whole tensor
            slot_indices = slot_indices.to(device)

            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
            adapter_segments = torch.tensor(
                adapter_segments, dtype=torch.int32, device=device
            )
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            cu_slots=cu_slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=self.prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        prefilling = False
        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_input_length = 0
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
            max_blocks = max(max_blocks, b.max_blocks)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_input_length = max(max_input_length, b.max_input_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
                    prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    for prompt_length, stopping_criteria in zip(
                        b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
            prefilling = prefilling or b.prefilling

        slots = batches[0].slots.new_empty(total_slots)
        cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
        if prefilling:
            input_ids = []
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
            adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            position_ids = batches[0].position_ids.new_empty(total_batch_size)
            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                total_batch_size
            )
            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
                total_batch_size
            )
            total_indices_size = sum(
                b.adapter_meta.adapter_indices.shape[0] for b in batches
            )
            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
                total_indices_size
            )
            adapter_segment_builder = SegmentConcatBuilder()
            adapter_set = set()

        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        block_tables = []
        cache_lengths = []
        all_input_ids = []

        prompt_lengths = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        prefill_logprob_tokens = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []
        prefilling_mask = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size
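                    # e.g. if the first batch holds 3 requests, index 1 in the
                    # second batch becomes 1 + 3 == 4 in the merged mapping.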

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)
            slots[slots_start_index:slots_end_index] = batch.slots
            cu_slots[start_index + 1 : end_index + 1] = (
                batch.cu_slots[1:] + cumulative_slots
            )

            if not prefilling:
                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
                slot_indices[start_index:end_index] = (
                    batch.slot_indices + cumulative_slots
                )
                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

                # Copy over adapter indices
                adapter_start_index = cumulative_adapter_indices_size
                adapter_end_index = (
                    cumulative_adapter_indices_size
                    + batch.adapter_meta.adapter_indices.shape[0]
                )
                adapter_indices[adapter_start_index:adapter_end_index] = (
                    batch.adapter_meta.adapter_indices
                )
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
                    batch.adapter_meta.adapter_segments,
                    batch.adapter_meta.segment_indices,
                )
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
                input_ids.extend(batch.input_ids)

            prefilling_mask.extend(batch.prefilling_mask)
            block_tables.extend(batch.block_tables)
            cache_lengths.extend(batch.cache_lengths)
            all_input_ids.extend(batch.all_input_ids)

            prompt_lengths.extend(batch.prompt_lengths)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_slots += len(batch.slots)
            cumulative_batch_size += len(batch)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        # We skip computing the speculative_ids when the batch size is too large, so
        # we must check that all batches have them, otherwise they must be discarded
        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
        else:
            speculative_ids = None

        if adapter_segment_builder is not None:
            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            slots=slots,
            cu_slots=cu_slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    def prepare_for_prefill(self):
        # Prepare values if we need to continue prefilling
        # Speculation must be ignored while we prefill even with chunking
        # it simplifies everything
        assert self.speculative_ids is None

        device = self.block_tables_tensor.device

        if isinstance(self.input_ids, list):
            if len(self) > 1:
                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
            else:
                input_ids = self.input_ids[0]
            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        self.input_lengths_tensor = torch.tensor(
            self.input_lengths, dtype=torch.int32, device=device
        )
        self.cu_seqlen_prefill = torch.nn.functional.pad(
            torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
        ).to(torch.int32)
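        # e.g. input_lengths == [3, 5, 2] -> cu_seqlen_prefill == [0, 3, 8, 10];
        # request i's prefill tokens occupy the ragged slice
        # cu_seqlen_prefill[i]:cu_seqlen_prefill[i + 1].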
        self.cache_lengths_tensor = torch.tensor(
            self.cache_lengths, dtype=torch.int32, device=device
        )

        # If the device supports Triton, we can use a fused kernel
        if has_triton():
            self.position_ids = torch.empty(
                len(self.input_ids), dtype=torch.int32, device=device
            )
            self.slot_indices = torch.empty(
                len(self.input_ids), dtype=torch.int64, device=device
            )
            cu_slots_gpu = self.cu_slots.to(device)

            prepare_position_slot_ids(
                self.max_input_length,
                self.cache_lengths_tensor,
                self.cu_seqlen_prefill,
                cu_slots_gpu,
                self.position_ids,
                self.slot_indices,
            )

        sliding_window = get_sliding_windows()
        position_ids = []
        slot_indices = []
        prefill_cache_indices = []
        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_cu_outlens = [0]

        # Cumulative length
        cumulative_length = 0
        cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0

        adapter_indices_list = []
        adapter_set = set()

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
            blocks,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
                self.block_tables,
            )
        ):
            next_chunk_length = input_length

            if not has_triton():
                # Position ids
                request_position_ids = torch.arange(
                    cache_length, cache_length + input_length, dtype=torch.int32
                )
                position_ids.append(request_position_ids)

                if not r.slots:
                    request_slots = [
                        s
                        for b in blocks
                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                    ]
                else:
                    request_slots = r.slots

                request_slot_indices = torch.arange(
                    cache_length + cumulative_slot_tokens,
                    cache_length + cumulative_slot_tokens + input_length,
                    dtype=torch.int64,
                )

                slot_indices.append(request_slot_indices)

                # Update
                cumulative_slot_tokens += len(request_slots)
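                # e.g. with cumulative_slot_tokens == 32, cache_length == 4 and
                # input_length == 3, this request gets slot indices [36, 37, 38],
                # i.e. the physical cache slots stored at slots[36:39].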

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )

            # Prefill logprobs is ignored if the request is done prefilling
            prefill_logprobs = r.prefill_logprobs and request_prefilling

            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

            if prefill_logprobs:
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1
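            # e.g. input lengths [4, 3] where only the first request asks for
            # prefill logprobs -> prefill_cu_outlens == [0, 4, 5]: four logits are
            # kept for request 0 and a single next-token logit for request 1.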

            if sliding_window is not None:
                prefill_cache_indices.append(request_prefill_cache_indices)

            ADAPTER_TO_INDEX = get_adapter_to_index()
            if ADAPTER_TO_INDEX:
                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
                adapter_indices_list.append(
                    torch.full((next_chunk_length,), adapter_index)
                )
                adapter_set.add(adapter_index)

            # Update
            cumulative_length += next_chunk_length

        if not all_prefill_logprobs and not no_prefill_logprobs:
            prefill_head_indices = []
            prefill_next_token_indices = []

            # Cumulative length
            cumulative_length = 0
            prefill_out_cumulative_length = 0

            for i, (
                r,
                input_length,
                request_prefilling,
            ) in enumerate(
                zip(
                    self.requests,
                    self.input_lengths,
                    self.prefilling_mask,
                )
            ):
                # Prefill logprobs is ignored if the request is done prefilling
                prefill_logprobs = r.prefill_logprobs and request_prefilling

                if prefill_logprobs:
                    prefill_head_indices.append(
                        torch.arange(
                            cumulative_length,
                            cumulative_length + input_length,
                            dtype=torch.int64,
                        )
                    )
                    prefill_next_token_indices.append(
                        prefill_out_cumulative_length + input_length - 1
                    )
                    prefill_out_cumulative_length += input_length
                else:
                    prefill_head_indices.append(
                        torch.tensor(
                            [cumulative_length + input_length - 1],
                            dtype=torch.int64,
                        )
                    )
                    prefill_next_token_indices.append(prefill_out_cumulative_length)
                    prefill_out_cumulative_length += 1

                # Update
                cumulative_length += input_length

        if len(self) > 1:
            if position_ids:
                position_ids = torch.cat(position_ids)
            if slot_indices:
                slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            if position_ids:
                position_ids = position_ids[0]
            if slot_indices:
                slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        if not has_triton():
            self.position_ids = position_ids.to(device)
            self.slot_indices = slot_indices.to(device)

        self.prefill_cu_outlens = prefill_cu_outlens
        self.prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        self.prefill_head_indices = prefill_head_indices
        self.prefill_next_token_indices = prefill_next_token_indices

        if adapter_set:
            adapter_indices = torch.cat(adapter_indices_list).to(
                dtype=torch.int64, device=device
            )
            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        else:
            adapter_indices = torch.zeros_like(self.input_ids)
            adapter_segments = [0, len(adapter_indices)]
            adapter_segment_indices = [len(adapter_indices) - 1]

        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        self.adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

    def __len__(self):
        return len(self.requests)


ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = [],
        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
        config_class=AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
        kv_cache_dtype: Optional[torch.dtype] = None,
        support_chunking: bool = True,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
                init_cpu_threads_env(rank_id=rank, world_size=world_size)
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            pass

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads // self.process_group.size()
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0
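        # Tensor-parallel sharding of the KV heads: with, for example, 8 KV heads
        # on 4 ranks each rank keeps 2 of them, while an MQA model with a single
        # KV head keeps that head replicated on every rank (illustrative numbers).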

        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj; reading `head_dim` from the config allows for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
            support_chunking=support_chunking,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def max_past(self) -> Optional[int]:
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.kv_cache = []
        empty_cache()
        self.kv_cache = [
            KVCache(
                num_blocks=num_blocks,
                num_heads=num_heads,
                head_size=head_size,
                dtype=dtype,
                device=device,
            )
            for _ in range(num_layers)
        ]
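        # One KVCache per transformer layer: each holds `num_blocks` paged blocks
        # of BLOCK_SIZE token slots for this rank's share of the KV heads.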

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = [max_s] * bs
        cache_lengths = [0] * bs
        input_lengths_tensor = (
            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        )
        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
        block_tables = torch.arange(
            max_bt, dtype=torch.int32, device=self.device
        ).repeat(bs)
        block_tables = block_tables.reshape((bs, max_bt))

        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=input_lengths,
                cache_lengths=cache_lengths,
                input_lengths_tensor=input_lengths_tensor,
                cache_lengths_tensor=cache_lengths_tensor,
                max_current_length=max_s,
            )
            from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )

            block_tables_ptr = torch.zeros(
                bs + 1, dtype=torch.int32, device=self.device
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
        else:
            state = None

        if (
            hasattr(self.model, "config")
            and hasattr(self.model.config, "model_type")
            and self.model.config.model_type == "qwen2_vl"
        ):
            if position_ids.dim() == 1:
                position_ids = self.model.get_position_ids(input_ids)

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
            "cache_lengths": cache_lengths_tensor,
            "state": state,
            "graph": graph,
        }
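        # These tensors become the static inputs/outputs of the captured graph:
        # every later replay reads and writes the same memory, so `forward` copies
        # the real (padded) batch into them before calling `graph.replay()`.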

        torch.cuda.synchronize()
        # Run once outside to warmup
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths_tensor=input_lengths_tensor,
            state=state,
            cache_lengths_tensor=cache_lengths_tensor,
        ):
            seqlen = Seqlen(
                input_lengths=input_lengths_tensor,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            del seqlen

            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                seqlen = Seqlen(
                    input_lengths=input_lengths_tensor,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=None,
                    max_q=1,
                    max_k=max_s,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=None,
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
                )
                self.cuda_graphs[bs]["logits"] = logits
                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def warmup(
        self,
        batch: FlashCausalLMBatch,
        max_input_tokens: Optional[int],
        max_total_tokens: Optional[int],
    ):
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
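        # Illustrative sizing (example numbers, not this config): BLOCK_SIZE=16,
        # 8 KV heads, head_size=128, 32 layers and a 2-byte dtype give
        # 16 * 8 * 128 elements per block and layer, times 2 (K and V), 2 bytes
        # and 32 layers, i.e. about 2 MiB of KV cache per block.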

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.kv_cache_dtype,
                self.device,
            )
            batch_num_blocks = batch.num_blocks

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, _batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )
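        # Worked example with illustrative numbers: 20 GiB free after the warmup
        # forward, TGI_WIGGLE_ROOM=0.95 and ~2 MiB per block yield
        # int(20480 * 0.95 / 2) = 9728 blocks, plus the warmup batch's own blocks.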

        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
        if max_total_tokens is None:
            if get_support_chunking():
                model_max_length = self.tokenizer.model_max_length
                max_input_tokens = (
                    min((num_blocks * BLOCK_SIZE - 1), model_max_length)
                    if max_input_tokens is None
                    else max_input_tokens
                )
                max_total_tokens = num_blocks * BLOCK_SIZE

            else:
                max_total_tokens = sum(batch.cache_lengths)
                max_input_tokens = (
                    max_total_tokens - 1
                    if max_input_tokens is None
                    else max_input_tokens
                )

        del _batch, batch
        self.kv_cache = []
        empty_cache()

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.kv_cache_dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            if (
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                log_master(
                    logger.info,
                    f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                )

                torch.cuda.tunable.set_filename(
                    tunableop_filepath, insert_device_ordinal=False
                )

                if os.path.isfile(tunableop_filepath):
                    log_master(
                        logger.info,
                        f"The file {tunableop_filepath} already exists and will be reused.",
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                    torch.cuda.tunable.write_file(tunableop_filepath)
                if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                    torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                )

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )

        assert max_input_tokens is not None
        assert max_total_tokens is not None
        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

    def tunableop_warmup(self, seqlen: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        cache_lengths_tensor = torch.zeros(
            seqlen, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.tensor(
            [0, seqlen], device=self.device, dtype=torch.int32
        )
        max_s = seqlen
        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
            max_q=1,
            max_k=seqlen,
        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=None,
            seqlen=seqlen,
            slots=slots,
            max_s=max_s,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
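            # Expand every request to `new_length` rows so the target model scores
            # the current token plus all draft tokens in one forward pass, e.g.
            # B=2 requests with 2 speculative ids each become 6 flattened positions.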
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)

            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
            # allocated
            slot_indices = (
                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            slots = batch.slots[slot_indices]

            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all speculative members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
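        # Decode can only reuse a captured graph whose static batch size is at
        # least the real batch size (inputs are padded up to it); prefill, or a
        # batch larger than any captured size, falls back to the eager path below.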

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    adapter_data=adapter_data,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.prefilling
        if prefill:
            batch.prepare_for_prefill()

        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
        else:
            prefill_logprobs = None
            next_token_logits = out

        finished_prefilling = True
        next_chunk_lengths = []
        current_prefilling_mask = batch.prefilling_mask
        if prefill:
            if get_support_chunking():
                next_prefilling_mask = []
                # Budget in tokens for the next batch
                # We remove (len(batch) - 1) to always have enough space for at least a single decode
                # for the remaining requests -1 because the first request does not need to be removed from the budget
                # (ex: you have one request in the batch, you want it to take the full budget not budget -1)
                batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
                # We reverse to prioritize older requests
                # zip() is not reversible so reverse the underlying lists instead
                for cache_length, input_length, prompt_length in zip(
                    reversed(batch.cache_lengths),
                    reversed(batch.input_lengths),
                    reversed(batch.prompt_lengths),
                ):
                    remaining_prefill_tokens = max(
                        prompt_length - cache_length - input_length, 0
                    )
                    if remaining_prefill_tokens > 0:
                        next_chunk_length = max(
                            min(remaining_prefill_tokens, batch_budget), 1
                        )
                        batch_budget -= next_chunk_length
                        finished_prefilling = False
                        next_prefilling_mask.append(True)
                    else:
                        # FIXME: use true number of accepted tokens instead of 1
                        # Since speculation will be turned off, this is always true
                        next_chunk_length = 1
                        next_prefilling_mask.append(False)
                    next_chunk_lengths.append(next_chunk_length)

                # Reverse back the obtained values
                next_chunk_lengths.reverse()
                next_prefilling_mask.reverse()
            else:
                # The model does not support chunking
                # We know we only do a single prefill
                finished_prefilling = True
                next_prefilling_mask = [False] * len(batch)

            batch.prefilling = not finished_prefilling
            batch.prefilling_mask = next_prefilling_mask
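        # Illustrative budget walk-through (example numbers): with a prefill budget
        # of 512 tokens and 3 requests, 510 tokens are left to distribute; the
        # oldest request gets as much of its remaining prompt as fits, then the
        # next one, and requests with prompt left stay in `prefilling_mask`.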

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_current_length],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        # Since we are done prefilling, all the tensors that concatenated values
        # across the requests instantly become of shape [BATCH_SIZE]
        if prefill and finished_prefilling:
            indices = batch.cu_seqlen_prefill[1:] - 1
            batch.position_ids = batch.position_ids[(..., indices)]
            batch.slot_indices = batch.slot_indices[indices]
            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                indices
            ]

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids,
            current_prefilling_mask,
            batch.prefilling_mask,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        # Cumulative length
        cu_accepted_ids = torch.nn.functional.pad(
            torch.cumsum(accepted_ids, dim=0), (1, 0)
        )
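        # cu_accepted_ids is an exclusive prefix sum over the accepted token counts,
        # e.g. accepted_ids = [1, 3, 2] -> cu_accepted_ids = [0, 1, 4, 6]; request i's
        # new tokens are next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].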
        cumulative_length = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            all_input_ids,
            n_accepted_ids,
            request_was_prefilling,
            request_is_prefilling,
        ) in enumerate(iterator):
            # Used to gather prefill logprobs
            # Copy batch.all_input_ids_tensor to prefill_token_indices
            if request.prefill_logprobs and request_was_prefilling:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]

                # Logprobs generated by the model are for the next token
                # So we need to translate the id tensor by 1
                ids = batch.all_input_ids_tensor[
                    i, cache_length + 1 : cache_length + input_length + 1
                ]
                if len(batch) > 1:
                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = ids

            # If the device does not support triton, we copy one by one
            if not request_is_prefilling and not has_triton():
                # Only save tokens if we are done prefilling for this request
                batch.all_input_ids_tensor[
                    i,
                    batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
            cumulative_length += input_length

        # If the device supports triton, we can use a fused kernel
        if has_triton():
            copy_next_input_ids_inplace(
                speculate + 1,
                batch.all_input_ids_tensor,
                batch.cache_lengths_tensor,
                batch.input_lengths_tensor,
                batch.prompt_lengths_tensor,
                next_input_ids,
                cu_accepted_ids,
            )

        # Update values
        # These values can be updated without a GPU -> CPU sync
        if not prefill or (prefill and finished_prefilling):
            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
            batch.speculative_ids = speculative_ids
            batch.position_ids += accepted_ids
            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
            batch.slot_indices += accepted_ids
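            # After this update every request is in plain decode state: exactly one
            # new input token, a cache grown by the tokens accepted this step, and
            # position/slot indices advanced by the same amount.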

        if prefill and prefill_logprobs:
            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
            torch.log_softmax(out, -1, out=out)
            prefill_logprobs_tensor = out
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # Does a GPU <-> CPU sync internally
        if prefill and finished_prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()

        # Update values if we need to continue prefilling
        # This represents the `else` case of the `Update values` if above
        # but since this require the `next_token_ids` to be on CPU, it is better to do it here
        if prefill and not finished_prefilling:
            # Speculation must be ignored while we prefill even with chunking
            # it simplifies everything
            assert batch.speculative_ids is None

            all_postfix_ids = []
            for i, (
                request_prefilling,
                next_token_id,
                all_input_ids,
                cache_length,
                input_length,
                next_chunk_length,
            ) in enumerate(
                zip(
                    batch.prefilling_mask,
                    next_token_ids,
                    batch.all_input_ids,
                    batch.cache_lengths,
                    batch.input_lengths,
                    next_chunk_lengths,
                )
            ):
                if request_prefilling:
                    next_cache_length = cache_length + input_length
                    # Get new prompt IDs to prefill
                    postfix_ids = all_input_ids[
                        next_cache_length : next_cache_length + next_chunk_length
                    ]
                else:
                    # This request is done prefilling, the new id is the one selected by the sampling method
                    postfix_ids = [next_token_id]

                all_postfix_ids.append(postfix_ids)

            batch.input_ids = all_postfix_ids

        start_decode = time.time_ns()

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            current_prefilling_mask,
            batch.prefilling_mask,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # Reset max_input_length
        batch.max_input_length = 0
        # For each member of the batch
        index = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            request_was_prefilling,
            request_is_prefilling,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Compute logprobs first as, even though we might skip the token,
            # it can still be required to compute the logprobs
            # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
            # this state to be stable
            if request.id % self.world_size == self.rank:
                # Prefill
                if request_was_prefilling and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]
                    if not request_is_prefilling:
                        # The request is done prefilling, meaning that we started generating new tokens
                        # The last logprob is a logprob for a generated token that was not part of the prompt
                        # We need to remove it
                        out_end_index -= 1

                    request_prefill_logprobs = prefill_logprobs[
                        out_start_index:out_end_index
                    ]
                    # Logprobs generated by the model are for the next token
                    # So we need to translate the id tensor by 1
                    prefill_token_ids = all_input_ids[
                        cache_length + 1 : cache_length + input_length + 1
                    ]

                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

                    if past_prefill_logprob_tokens is None:
                        # add nan for cached prompt tokens/first token
                        request_prefill_logprobs = [float("nan")] * (
                            cache_length + 1
                        ) + request_prefill_logprobs
                        prefill_token_ids = (
                            all_input_ids[: cache_length + 1] + prefill_token_ids
                        )

                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_logprob_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                    if past_prefill_logprob_tokens is not None:
                        prefill_logprob_tokens = (
                            past_prefill_logprob_tokens + prefill_logprob_tokens
                        )

                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                else:
                    batch.prefill_logprob_tokens[i] = None

            # If the request is still prefilling, the tokens we decoded should be ignored
            if request_is_prefilling:
                # Make sure that we do not stop as even though this request did not create a token, it is still
                # processing
                stopped = False
                new_input_length = next_chunk_lengths[i]
                new_cache_length = cache_length + input_length
            else:
                new_input_length = 1
                new_cache_length = cache_length + input_length + n_accepted_ids - 1
                # Append next token to all tokens
                next_token_texts = []
                left = 0

                if n_accepted_ids > 1:
                    log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

                current_stopped = False
                for j in range(index, index + n_accepted_ids):
                    # Generated token
                    next_token_id = next_token_ids[j]
                    all_input_ids.append(next_token_id)
                    next_token_text, prefix_offset, read_offset = self.decode_token(
                        all_input_ids,
                        prefix_offset,
                        read_offset,
                    )
                    next_token_texts.append(next_token_text)

                    stop, reason = stopping_criteria(
                        next_token_id,
                        next_token_text,
                    )

                    if stop:
                        left = index + n_accepted_ids - j - 1
                        current_stopped = True
                        break
                    else:
                        current_stopped = False
                stopped = stopped and current_stopped

                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
                _next_token_logprobs = next_token_logprobs[
                    index : index + n_accepted_ids - left
                ]

                # Shard generations
                # All generations will be appended in the rust sharded client
                if request.id % self.world_size == self.rank:
                    if stop:
                        # Decode generated tokens
                        output_text, _, _ = self.decode_token(
                            all_input_ids,
                            prefix_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens
                            - 1,
                            read_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens,
                            skip_special_tokens=True,
                        )
                        generated_text = GeneratedText(
                            output_text,
                            stopping_criteria.current_tokens,
                            reason,
                            seed if do_sample else None,
                        )
                    else:
                        generated_text = None

                    if top_n_tokens > 0:
                        all_top_tokens = []
                        for top_token_ids, top_token_logprobs in zip(
                            top_token_ids, top_token_logprobs
                        ):
                            toptoken_texts = self.tokenizer.batch_decode(
                                top_token_ids,
                                clean_up_tokenization_spaces=False,
                                skip_special_tokens=False,
                            )
                            special_toptokens = [
                                token_id in self.all_special_ids
                                for token_id in top_token_ids
                            ]
                            top_tokens = Tokens(
                                top_token_ids,
                                top_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                            all_top_tokens.append(top_tokens)
                        top_tokens = all_top_tokens
                    else:
                        top_tokens = None

                    generation = Generation(
                        request.id,
                        batch.prefill_logprob_tokens[i],
                        Tokens(
                            _next_token_ids,
                            _next_token_logprobs,
                            next_token_texts,
                            [nid in self.all_special_ids for nid in _next_token_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )

                    generations.append(generation)

                # accept each new token for this specific request since we may
                # have more than one new token per request with speculative decoding
                for next_token_id in _next_token_ids:
                    batch.next_token_chooser = (
                        batch.next_token_chooser.advance_grammar_single(
                            i, next_token_id
                        )
                    )

            # Update values
            index += n_accepted_ids
            batch.cache_lengths[i] = new_cache_length
            batch.max_input_length = max(batch.max_input_length, new_input_length)
            batch.input_lengths[i] = new_input_length
            current_length = new_cache_length + new_input_length
            batch.max_current_length = max(batch.max_current_length, current_length)

            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        if prefill and finished_prefilling:
            # We do not need prefill tensors anymore
            batch.cu_seqlen_prefill = None
            batch.prefill_cache_indices = None
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

    def _forward_context(
        self,
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        input_lengths_tensor: torch.Tensor,
        cache_lengths_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
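        # Only flashinfer keeps per-batch planning state (prefill / decode
        # wrappers); for the other attention backends this context is a no-op.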
        if ATTENTION != "flashinfer":
            return nullcontext()

        from text_generation_server.layers.attention.flashinfer import (
            use_decode_state,
            use_prefill_with_paged_kv_state,
        )

        if cu_seqlen_prefill is not None:
            return use_prefill_with_paged_kv_state(
                state=(
                    state if state is not None else self.prefill_with_paged_kv_state
                ),
                block_tables=block_tables,
                cu_seqlens=cu_seqlen_prefill,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )
        else:
            assert input_lengths_tensor is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                block_tables=block_tables,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                kv_cache_dtype=self.kv_cache_dtype,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )