from contextlib import nullcontext
import math
import os
import time
import torch
import torch.distributed

import numpy as np

from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
from typing import (
    Any,
    ContextManager,
    Iterable,
    Optional,
    Tuple,
    List,
    Type,
    Dict,
    Union,
)

from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
    get_support_chunking,
    get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
    MEM_POOL,
    ATTENTION,
    BLOCK_SIZE,
    CUDA_GRAPHS,
    TGI_WIGGLE_ROOM,
    get_adapter_to_index,
)
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> int:
    global SLIDING_WINDOW
    return SLIDING_WINDOW


def init_cpu_threads_env(rank_id: int, world_size: int):
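    """Pin this worker's threads and memory to a single NUMA node.

    This only runs when the optional `numa` package is installed. The rank is
    mapped to a NUMA node, OMP threads are bound to a contiguous slice of that
    node's physical cores, and `OMP_NUM_THREADS` (when set) overrides the
    computed per-rank thread count.
    """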
    import importlib.util

    if importlib.util.find_spec("numa") is not None:
        import numa
        import psutil

        nodes = numa.info.get_max_node() + 1
        rank_per_node = math.ceil(world_size / nodes)
        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
        node_id = int(rank_id / rank_per_node)
        rank_offset_per_node = rank_id % rank_per_node
        if os.getenv("OMP_NUM_THREADS") is None:
            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
        else:
            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))

        if len(numa.memory.get_membind_nodes()) == nodes:
            numa.memory.set_membind_nodes((node_id))
        torch.set_num_threads(num_cpus_per_rank)
        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
            cpu_start = num_cpus_per_rank * rank_offset_per_node
            numa.schedule.run_on_cpus(
                0,
                *(
                    numa.info.node_to_cpus(node_id)[
                        cpu_start : cpu_start + num_cpus_per_rank
                    ]
                ),
            )
        logger.info(
            f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
        )


@dataclass
class FlashCausalLMBatch(Batch):
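    """State of a batch processed with flash/paged attention.

    Per-request lists (lengths, block tables, stopping criteria, ...) are always
    populated; several tensors (position ids, slots, prefill indices, adapter
    metadata, ...) are only materialized by `prepare_for_prefill` /
    `generate_token` and stay `None` until then.
    """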
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    # Can be a list for easy filtering
    # If `input_ids` is a list, it needs to be materialized to a tensor first
    input_ids: Union[torch.Tensor, List[List[int]]]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    position_ids: Optional[torch.Tensor]
    speculative_ids: Optional[torch.Tensor]

    # Set when creating the batch
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    slot_indices: Optional[torch.Tensor]

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    slots: Optional[torch.Tensor]

    max_input_length: int
    max_current_length: int

    # Whether this batch contains at least one request that is prefilling
    prefilling: bool
    # Whether each request is prefilling
    prefilling_mask: List[bool]

    # Prefill metadata tensors to efficiently compute logprobs
    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices are used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_head_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_next_token_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_cu_outlens: Optional[List[int]]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_logprob_tokens: List[Optional[Tokens]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # size [b], containing the number of blocks that can be retrieved from the cache
    cache_lengths: List[int]
    prompt_lengths: List[int]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    input_lengths_tensor: Optional[torch.Tensor]
    cache_lengths_tensor: Optional[torch.Tensor]
    prompt_lengths_tensor: torch.Tensor

    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    adapter_meta: Optional[AdapterBatchMetadata]

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
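        """Serialize this batch into a `generate_pb2.CachedBatch` message for the router."""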
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
            current_tokens=(
                sum([len(i) for i in self.input_ids])
                if isinstance(self.input_ids, list)
                else len(self.input_ids)
            ),
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
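        """Tokenize the text chunks of every request, honoring per-request
        truncation and `add_special_tokens`, and return one list of input ids
        per request."""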
        max_length = 0
        all_input_ids = []
        batch_size = 0
        for r in requests:
            batch_size += 1
            inputs = concat_text_chunks(r.input_chunks.chunks)
            input_ids = tokenizer(
                inputs,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            max_length = max(max_length, len(input_ids))
            all_input_ids.append(input_ids)
        return all_input_ids

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
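        """Build a prefilling batch from already tokenized inputs.

        For every request this computes the postfix ids (the part of the prompt
        that still has to be prefilled after `cache_len`), allocates or reuses
        paged-attention blocks and collects the generation parameters. Tensors
        that depend on the forward pass are left as `None` and are filled in by
        `prepare_for_prefill`.
        """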
        speculate = get_speculate()

        cache_lengths = []
        input_lengths = []
        prompt_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        all_postfix_ids = []
        requests_idx_mapping = {}

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        num_blocks = 0
        max_input_length = 0
        max_current_length = 0
        max_length = 0
        max_blocks = 0

        block_tables = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            prompt_length = len(tokenized_input)
            prompt_lengths.append(prompt_length)

            cache_length = r.cache_len

            assert (
                cache_length <= prompt_length
            ), f"Prefix {cache_length} vs input {prompt_length}"
            if cache_length == prompt_length:
                assert False, "unreachable"

            # `chunk_len` is an optional field in the protobuf
            # It is only set if the model supports chunking
            if r.HasField("chunk_len"):
                input_length = r.chunk_len

                if cache_length + input_length < prompt_length:
                    # FIXME: speculate is not supported for context chunking at the moment
                    assert speculate == 0
                    assert get_support_chunking()
                    assert input_length > 0

                postfix_ids = tokenized_input[
                    cache_length : cache_length + input_length
                ]
                assert (
                    len(postfix_ids) == input_length
                ), "Rust and Python tokenizers are not aligned"
            else:
                # Use all the remaining ids
                postfix_ids = tokenized_input[cache_length:]
                input_length = len(postfix_ids)

            input_lengths.append(input_length)

            prefix_offsets.append(prompt_length - 5)
            read_offsets.append(prompt_length)

            all_postfix_ids.append(postfix_ids)
            all_input_ids.append(tokenized_input)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length

            # Tokens that need to be mapped to blocks.
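            # For example, with BLOCK_SIZE = 16, a 10-token prompt,
            # max_new_tokens = 20 and no speculation needs
            # block_tokens = 10 + 20 - 1 = 29 tokens, i.e. 2 blocks.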
            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
            else:
                request_blocks = r.blocks

            block_tables.append(request_blocks)

            cache_lengths.append(cache_length)
            num_blocks += len(request_blocks)

            # Update
            max_blocks = max(max_blocks, len(request_blocks))
            max_input_length = max(max_input_length, input_length)
            max_current_length = max(max_current_length, cache_length + input_length)
            max_length = max(
                max_length,
                prompt_length + max_new_tokens + speculative_length,
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)
        prompt_lengths_tensor = torch.tensor(
            prompt_lengths, dtype=torch.int32, device=device
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=True,
            prefilling_mask=[True] * len(pb.requests),
            prefill_logprob_tokens=[None] * len(pb.requests),
            input_lengths=input_lengths,
            prompt_lengths=prompt_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=None,
            slots=None,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            cache_lengths_tensor=None,
            input_lengths_tensor=None,
            adapter_meta=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        assert len(pb.requests) > 0
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
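        """Return a new batch restricted to `request_ids`.

        Per-request lists and tensors are re-indexed; for decoding batches the
        slot bookkeeping is rebuilt so only the slots of the surviving requests
        are kept.
        """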
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.block_tables_tensor.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_input_length = 0
        max_current_length = 0

        requests = []
        block_tables = []
        all_input_ids = []
        input_ids = []

        prompt_lengths = []
        input_lengths = []
        cache_lengths = []
        prefix_offsets = []
        read_offsets = []

        prefilling_mask = []
        prefill_logprob_tokens = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()

        num_blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Prefilling
            request_prefilling = self.prefilling_mask[idx]
            prefilling_mask.append(request_prefilling)

            # Get length
            request_input_length = self.input_lengths[idx]
            request_cache_length = self.cache_lengths[idx]
            max_input_length = max(max_input_length, request_input_length)
            max_current_length = max(
                max_current_length, request_cache_length + request_input_length
            )

            all_input_ids.append(self.all_input_ids[idx])

            prompt_lengths.append(self.prompt_lengths[idx])
            input_lengths.append(request_input_length)
            cache_lengths.append(request_cache_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])
            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)

            # Input ids if the request was part of a prefilling batch
            # If the batch was decoding we can index into the tensor directly later
            if self.prefilling:
                input_ids.append(self.input_ids[idx])
            else:
                # Copy to tensor (CPU)
                slot_indices[i] = cumulative_max_length

                remaining_tokens = (
                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
                )

                # Set slice
                slot_filtering_indices[
                    self.slot_indices[idx] : self.slot_indices[idx]
                    + request_input_length
                    + remaining_tokens
                    - 1
                ] = True

                cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )
        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

        if self.prefilling:
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            slots = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
        else:
            # Index into tensors
            input_ids = self.input_ids[indices]
            position_ids = self.position_ids[indices]
            adapter_indices = self.adapter_meta.adapter_indices[indices]
            input_lengths_tensor = self.input_lengths_tensor[indices]
            slots = self.slots[slot_filtering_indices]
            cache_lengths_tensor = self.cache_lengths_tensor[indices]

            # Move to GPU now that we have the whole tensor
            slot_indices = slot_indices.to(device)

            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
            adapter_segments = torch.tensor(
                adapter_segments, dtype=torch.int32, device=device
            )
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=self.prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
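        """Merge several batches into a single one.

        Tensors are concatenated only when no batch is still prefilling;
        otherwise the merged batch stays in the prefilling state and its tensor
        fields are rebuilt later by `prepare_for_prefill`.
        """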
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        prefilling = False
        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_input_length = 0
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
            max_blocks = max(max_blocks, b.max_blocks)
            # If `b` is prefilling and was just filtered, `b.slots` is None
            # `total_slots` is not used if any of the batches is prefilling
            total_slots += len(b.slots) if not b.prefilling else 0
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_input_length = max(max_input_length, b.max_input_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
                    prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    for prompt_length, stopping_criteria in zip(
                        b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
            prefilling = prefilling or b.prefilling

        if prefilling:
            input_ids = []
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slots = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
            adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            position_ids = batches[0].position_ids.new_empty(total_batch_size)
            slots = batches[0].slots.new_empty(total_slots)
            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                total_batch_size
            )
            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
                total_batch_size
            )
            total_indices_size = sum(
                b.adapter_meta.adapter_indices.shape[0] for b in batches
            )
            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
                total_indices_size
            )
            adapter_segment_builder = SegmentConcatBuilder()
            adapter_set = set()

        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        block_tables = []
        cache_lengths = []
        all_input_ids = []

        prompt_lengths = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        prefill_logprob_tokens = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []
        prefilling_mask = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

            if not prefilling:
                slots_start_index = cumulative_slots
                slots_end_index = cumulative_slots + len(batch.slots)

                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
                slots[slots_start_index:slots_end_index] = batch.slots
                slot_indices[start_index:end_index] = (
                    batch.slot_indices + cumulative_slots
                )
                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

                # Copy over adapter indices
                adapter_start_index = cumulative_adapter_indices_size
                adapter_end_index = (
                    cumulative_adapter_indices_size
                    + batch.adapter_meta.adapter_indices.shape[0]
                )
                adapter_indices[adapter_start_index:adapter_end_index] = (
                    batch.adapter_meta.adapter_indices
                )
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
                    batch.adapter_meta.adapter_segments,
                    batch.adapter_meta.segment_indices,
                )

                # Update
                cumulative_slots += len(batch.slots)
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
                input_ids.extend(batch.input_ids)

            prefilling_mask.extend(batch.prefilling_mask)
            block_tables.extend(batch.block_tables)
            cache_lengths.extend(batch.cache_lengths)
            all_input_ids.extend(batch.all_input_ids)

            prompt_lengths.extend(batch.prompt_lengths)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        if adapter_segment_builder is not None:
            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            slots=slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    def prepare_for_prefill(self):
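        """Materialize the tensors needed for the prefill forward pass.

        Builds position ids, `cu_seqlen_prefill`, slot indices, the optional
        sliding-window `prefill_cache_indices`, prefill logprob indices and the
        adapter metadata from the per-request lists stored on the batch.
        """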
        # Prepare values if we need to continue prefilling
        # Speculation must be ignored while we prefill even with chunking
        # it simplifies everything
        assert self.speculative_ids is None

        sliding_window = get_sliding_windows()
        position_ids = []
        cu_seqlen_prefill = [0]
        slot_indices = []
        prefill_cache_indices = []
        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        # Cumulative length
        cumulative_length = 0
        cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0

        slots = []
        adapter_indices_list = []
        adapter_set = set()

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
            blocks,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
                self.block_tables,
            )
        ):
            next_chunk_length = input_length
            # Position ids
            request_position_ids = torch.arange(
                cache_length, cache_length + input_length, dtype=torch.int32
            )
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            if not r.slots:
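                # The router did not provide slots (e.g. during warmup), so
                # derive them from the block ids: block `b` covers slots
                # [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE). For example, with
                # BLOCK_SIZE = 16, blocks [3, 7] expand to slots 48..63 and 112..127.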
                request_slots = [
                    s
                    for b in blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_slots = r.slots

            request_slots = request_slots[cache_length:]
            request_slot_indices = torch.arange(
                cumulative_slot_tokens,
                cumulative_slot_tokens + input_length,
                dtype=torch.int64,
            )

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )

            # Prefill logprobs is ignored if the request is done prefilling
            prefill_logprobs = r.prefill_logprobs and request_prefilling

            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

            if prefill_logprobs:
                prefill_head_indices.append(
                    torch.arange(
                        cumulative_length,
                        cumulative_length + input_length,
                        dtype=torch.int64,
                    )
                )
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1],
                        dtype=torch.int64,
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            slots.extend(request_slots)
            slot_indices.append(request_slot_indices)

            if sliding_window is not None:
                prefill_cache_indices.append(request_prefill_cache_indices)

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
            adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
            adapter_set.add(adapter_index)

            # Update
            cumulative_length += next_chunk_length
            cumulative_slot_tokens += len(request_slots)

        device = self.block_tables_tensor.device

        if isinstance(self.input_ids, list):
            if len(self) > 1:
                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
            else:
                input_ids = self.input_ids[0]
            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        if len(self) > 1:
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        self.prefill_cu_outlens = prefill_cu_outlens
        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        self.cu_seqlen_prefill = cu_seqlen_prefill
        self.position_ids = position_ids.to(device)
        self.slot_indices = slot_indices.to(device)
        self.prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        self.input_lengths_tensor = torch.tensor(
            self.input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        self.prefill_head_indices = prefill_head_indices
        self.prefill_next_token_indices = prefill_next_token_indices
        self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
        self.cache_lengths_tensor = torch.tensor(
            self.cache_lengths, dtype=torch.int32, device=device
        )
        adapter_indices = torch.cat(adapter_indices_list).to(
            dtype=torch.int64, device=device
        )
        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )
        self.adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

    def __len__(self):
        return len(self.requests)


ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class FlashCausalLM(Model):
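    """Base class for causal language models served with flash/paged attention.

    Handles weight loading and sharding, KV-cache allocation, optional CUDA
    graph capture for decode, and runs prefill/decode forwards over
    `FlashCausalLMBatch` instances.
    """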
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = [],
        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
        config_class: PreTrainedTokenizerBase = AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
        kv_cache_dtype: Optional[torch.dtype] = None,
        support_chunking: bool = True,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
                init_cpu_threads_env(rank_id=rank, world_size=world_size)
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            pass

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads // self.process_group.size()
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0

        if head_size is None:
            # Some models use GQA and different head sizes for o_proj and
            # q_proj; `head_dim` (when present) accounts for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        super().__init__(
drbh's avatar
drbh committed
1230
            model_id=model_id,
1231
            model=model,
1232
1233
1234
1235
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
1236
1237
            rank=rank,
            world_size=world_size,
1238
            sliding_window=config.sliding_window,
1239
            support_chunking=support_chunking,
1240
1241
1242
1243
1244
1245
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def max_past(self) -> Optional[int]:
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
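        """(Re)allocate the paged KV cache: `num_blocks` blocks for each of the
        `num_layers` layers, using the given head layout, dtype and device."""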
        self.kv_cache = []
        empty_cache()
        self.kv_cache = [
            KVCache(
                num_blocks=num_blocks,
                num_heads=num_heads,
                head_size=head_size,
                dtype=dtype,
                device=device,
            )
            for _ in range(num_layers)
        ]

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
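        """Capture a CUDA graph for decoding at batch size `bs`.

        One eager decode forward is run first as a warmup, then the same
        forward is recorded into a CUDA graph whose input and output tensors
        are kept in `self.cuda_graphs[bs]` for replay.
        """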
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = [max_s] * bs
        cache_lengths = [0] * bs
        input_lengths_tensor = (
            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        )
        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
        block_tables = torch.arange(
            max_bt, dtype=torch.int32, device=self.device
        ).repeat(bs)
        block_tables = block_tables.reshape((bs, max_bt))

        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=input_lengths,
1290
                cache_lengths=cache_lengths,
Nicolas Patry's avatar
Nicolas Patry committed
1291
1292
            )
            from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )

            block_tables_ptr = torch.zeros(
                bs + 1, dtype=torch.int32, device=self.device
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
        else:
            state = None

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
            "cache_lengths": cache_lengths_tensor,
            "state": state,
            "graph": graph,
        }
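        # The tensors stored above are the graph's static buffers: `forward` copies
        # live (possibly smaller) batches into them, replays the captured graph, and
        # slices the resulting logits back down to the real batch size.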

        torch.cuda.synchronize()
        # Run once outside to warmup
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths_tensor=input_lengths_tensor,
            state=state,
            cache_lengths_tensor=cache_lengths_tensor,
        ):
            seqlen = Seqlen(
                input_lengths=input_lengths_tensor,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            del seqlen

            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                seqlen = Seqlen(
                    input_lengths=input_lengths_tensor,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=None,
                    max_q=1,
                    max_k=max_s,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=None,
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
                )
                self.cuda_graphs[bs]["logits"] = logits
                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.kv_cache_dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
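        # Illustrative example (hypothetical config): with BLOCK_SIZE=16, 32 layers,
        # 8 KV heads of size 128 and an fp16 cache (2 bytes per element), one block
        # costs 32 * (16 * 8 * 128) * 2 * 2 bytes = 2 MiB of KV-cache memory.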

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
        batch_num_blocks = batch.num_blocks if batch is not None else 0

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )

        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")

        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.kv_cache_dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            if (
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                log_master(
                    logger.info,
                    f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                )

                torch.cuda.tunable.set_filename(
                    tunableop_filepath, insert_device_ordinal=False
                )

                if os.path.isfile(tunableop_filepath):
                    log_master(
                        logger.info,
                        f"The file {tunableop_filepath} already exists and will be reused.",
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                    torch.cuda.tunable.write_file(tunableop_filepath)
                if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                    torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                )

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )
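        # The value returned below is the total number of tokens the KV cache can
        # hold (number of blocks times tokens per block).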

        return int(num_blocks * BLOCK_SIZE)

    def tunableop_warmup(self, seqlen: int):
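        # Run a single dummy prefill of length `seqlen` so that TunableOp can
        # benchmark and pick GEMM kernels for the shapes produced by this sequence
        # length; results are persisted to the CSV file selected in `warmup`.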
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        cache_lengths_tensor = torch.zeros(
            seqlen, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.tensor(
            [0, seqlen], device=self.device, dtype=torch.int32
        )
        max_s = seqlen
        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
            max_q=1,
            max_k=seqlen,
        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=None,
            seqlen=seqlen,
            slots=slots,
            max_s=max_s,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
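            # Each request is expanded into `new_length` rows (the sampled token plus
            # its speculative candidates) so the target model can score all of them
            # in a single forward pass.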
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
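        # CUDA graphs are only captured for decode steps, so prefill (or a batch size
        # without a captured graph) always takes the eager path below.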

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    adapter_data=adapter_data,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.prefilling
        if prefill:
            batch.prepare_for_prefill()

        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
        else:
            prefill_logprobs = None
            next_token_logits = out
            next_adapter_indices = batch.adapter_meta.adapter_indices

        finished_prefilling = True
        next_chunk_lengths = []
        current_prefilling_mask = batch.prefilling_mask
        if prefill:
            if get_support_chunking():
                next_prefilling_mask = []
                # Budget in tokens for the next batch
                # We subtract (len(batch) - 1) so that there is always room for at least
                # one decode token for each of the remaining requests; the first request
                # is not counted because it consumes the budget itself
                # (e.g. a single request in the batch should get the full budget, not budget - 1)
                batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
                # We reverse to prioritize older requests
                # zip() is not reversible so reverse the underlying lists instead
                for cache_length, input_length, prompt_length in zip(
                    reversed(batch.cache_lengths),
                    reversed(batch.input_lengths),
                    reversed(batch.prompt_lengths),
                ):
                    remaining_prefill_tokens = max(
                        prompt_length - cache_length - input_length, 0
                    )
                    if remaining_prefill_tokens > 0:
                        next_chunk_length = max(
                            min(remaining_prefill_tokens, batch_budget), 1
                        )
                        batch_budget -= next_chunk_length
                        finished_prefilling = False
                        next_prefilling_mask.append(True)
                    else:
                        # FIXME: use true number of accepted tokens instead of 1
                        # Since speculation will be turned off, this is always true
                        next_chunk_length = 1
                        next_prefilling_mask.append(False)
                    next_chunk_lengths.append(next_chunk_length)

                # Reverse back the obtained values
                next_chunk_lengths.reverse()
                next_prefilling_mask.reverse()
            else:
                # The model does not support chunking
                # We know we only do a single prefill
                finished_prefilling = True
                next_prefilling_mask = [False] * len(batch)

            batch.prefilling = not finished_prefilling
            batch.prefilling_mask = next_prefilling_mask

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_current_length],
OlivierDehaene's avatar
OlivierDehaene committed
1817
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        # Once prefilling is done, all the tensors that concatenated values across requests
        # collapse to shape [BATCH_SIZE]
        if prefill and finished_prefilling:
            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
                len(batch)
            )
        elif not prefill:
            next_position_ids = batch.position_ids

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids,
            current_prefilling_mask,
            batch.prefilling_mask,
        )

        # We do two for loops: the first one can run completely asynchronously from the GPU,
        # while for the second one we first need a GPU <-> CPU sync.
        # It is faster if we delay this sync for as long as possible.

        # For each member of the batch
        index = 0
        # Cumulative length
        cumulative_length = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            all_input_ids,
            n_accepted_ids,
            request_was_prefilling,
            request_is_prefilling,
        ) in enumerate(iterator):
            if prefill and finished_prefilling:
                # Indexing metadata
                _start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Initialize adapter indices
                # In decode, we only have one token per row in the batch, so grab last index
                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
                    end_index - 1
                ]

            # Used to gather prefill logprobs
            # Copy batch.all_input_ids_tensor to prefill_token_indices
            if request.prefill_logprobs and request_was_prefilling:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]

                # Logprobs generated by the model are for the next token
                # So we need to translate the id tensor by 1
                ids = batch.all_input_ids_tensor[
                    i, cache_length + 1 : cache_length + input_length + 1
                ]
                if len(batch) > 1:
                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = ids

            if not request_is_prefilling:
                # Only save tokens if we are done prefilling for this request
                for j in range(n_accepted_ids):
                    batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
                        next_input_ids[index + j]
                    )
            index += n_accepted_ids
            cumulative_length += input_length

        # Update values
        # These values can be updated without a GPU -> CPU sync
        if not prefill or (prefill and finished_prefilling):
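            # `accepted_ids.cumsum(-1) - 1` picks, for every request, the index of its
            # last accepted token inside the flattened `next_input_ids`; that token
            # becomes the input of the next decode step.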
            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
            batch.speculative_ids = speculative_ids
            batch.position_ids = next_position_ids + accepted_ids
            batch.cache_lengths_tensor += batch.input_lengths_tensor
            batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
            batch.slot_indices += accepted_ids
            batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill and prefill_logprobs:
            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
            torch.log_softmax(out, -1, out=out)
            prefill_logprobs_tensor = out
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # Does a GPU <-> CPU sync internally
        if prefill and finished_prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()

        # Update values if we need to continue prefilling
        # This represents the `else` case of the `Update values` if above
        # but since this requires the `next_token_ids` to be on CPU, it is better to do it here
        if prefill and not finished_prefilling:
            # Speculation must be ignored while we prefill even with chunking
            # it simplifies everything
            assert batch.speculative_ids is None

            all_postfix_ids = []
            for i, (
                request_prefilling,
                next_token_id,
                all_input_ids,
                cache_length,
                input_length,
                next_chunk_length,
            ) in enumerate(
                zip(
                    batch.prefilling_mask,
                    next_token_ids,
                    batch.all_input_ids,
                    batch.cache_lengths,
                    batch.input_lengths,
                    next_chunk_lengths,
                )
            ):
                if request_prefilling:
                    next_cache_length = cache_length + input_length
                    # Get new prompt IDs to prefill
                    postfix_ids = all_input_ids[
                        next_cache_length : next_cache_length + next_chunk_length
                    ]
                else:
                    # This request is done prefilling, the new id is the one selected by the sampling method
                    postfix_ids = [next_token_id]

                all_postfix_ids.append(postfix_ids)

            batch.input_ids = all_postfix_ids

        start_decode = time.time_ns()

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            current_prefilling_mask,
            batch.prefilling_mask,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # Reset max_input_length
        batch.max_input_length = 0
        # For each member of the batch
        index = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            request_was_prefilling,
            request_is_prefilling,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Compute logprobs first: even though we might skip the token,
            # its logprobs can still be required.
            # We take request.id modulo world_size because it is robust to batch.filter,
            # whereas the index in the batch is not, and we need this state to be stable.
            if request.id % self.world_size == self.rank:
                # Prefill
                if request_was_prefilling and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]
                    if not request_is_prefilling:
                        # The request is done prefilling, meaning that we started generating new tokens
                        # The last logprob is a logprob for a generated token that was not part of the prompt
                        # We need to remove it
                        out_end_index -= 1

                    request_prefill_logprobs = prefill_logprobs[
                        out_start_index:out_end_index
                    ]
                    # Logprobs generated by the model are for the next token
                    # So we need to translate the id tensor by 1
                    prefill_token_ids = all_input_ids[
                        cache_length + 1 : cache_length + input_length + 1
                    ]

                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

                    if past_prefill_logprob_tokens is None:
                        # add nan for cached prompt tokens/first token
                        request_prefill_logprobs = [float("nan")] * (
                            cache_length + 1
                        ) + request_prefill_logprobs
                        prefill_token_ids = (
                            all_input_ids[: cache_length + 1] + prefill_token_ids
                        )

                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_logprob_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                    if past_prefill_logprob_tokens is not None:
                        prefill_logprob_tokens = (
                            past_prefill_logprob_tokens + prefill_logprob_tokens
                        )

                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                else:
                    batch.prefill_logprob_tokens[i] = None

            # If the request is still prefilling, the tokens we decoded should be ignored
            if request_is_prefilling:
                # Make sure that we do not stop as even though this request did not create a token, it is still
                # processing
                stopped = False
                new_input_length = next_chunk_lengths[i]
            else:
                new_input_length = n_accepted_ids
                # Append next token to all tokens
                next_token_texts = []
                left = 0

                if n_accepted_ids > 1:
                    log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

                current_stopped = False
                for j in range(index, index + n_accepted_ids):
                    # Generated token
                    next_token_id = next_token_ids[j]
                    all_input_ids.append(next_token_id)
                    next_token_text, prefix_offset, read_offset = self.decode_token(
                        all_input_ids,
                        prefix_offset,
                        read_offset,
                    )
                    next_token_texts.append(next_token_text)

                    stop, reason = stopping_criteria(
                        next_token_id,
                        next_token_text,
                    )

                    if stop:
                        left = index + n_accepted_ids - j - 1
                        current_stopped = True
                        break
                    else:
                        current_stopped = False
                stopped = stopped and current_stopped

                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
                _next_token_logprobs = next_token_logprobs[
                    index : index + n_accepted_ids - left
                ]

                # Shard generations
                # All generations will be appended in the rust sharded client
                if request.id % self.world_size == self.rank:
                    if stop:
                        # Decode generated tokens
                        output_text, _, _ = self.decode_token(
                            all_input_ids,
                            prefix_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens
                            - 1,
                            read_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens,
                            skip_special_tokens=True,
                        )
                        generated_text = GeneratedText(
                            output_text,
                            stopping_criteria.current_tokens,
                            reason,
                            seed if do_sample else None,
                        )
                    else:
                        generated_text = None

                    if top_n_tokens > 0:
                        all_top_tokens = []
                        for top_token_ids, top_token_logprobs in zip(
                            top_token_ids, top_token_logprobs
                        ):
                            toptoken_texts = self.tokenizer.batch_decode(
                                top_token_ids,
                                clean_up_tokenization_spaces=False,
                                skip_special_tokens=False,
                            )
                            special_toptokens = [
                                token_id in self.all_special_ids
                                for token_id in top_token_ids
                            ]
                            top_tokens = Tokens(
                                top_token_ids,
                                top_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                            all_top_tokens.append(top_tokens)
                        top_tokens = all_top_tokens
                    else:
                        top_tokens = None

                    generation = Generation(
                        request.id,
                        batch.prefill_logprob_tokens[i],
                        Tokens(
                            _next_token_ids,
                            _next_token_logprobs,
                            next_token_texts,
                            [nid in self.all_special_ids for nid in _next_token_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )

                    generations.append(generation)

                # accept each new token for this specific request since we may
                # have more than one new token per request with speculative decoding
                for next_token_id in _next_token_ids:
                    batch.next_token_chooser = (
                        batch.next_token_chooser.advance_grammar_single(
                            i, next_token_id
                        )
                    )

            # Update values
            index += n_accepted_ids
            current_cache_length = cache_length + input_length
            batch.cache_lengths[i] = current_cache_length
            current_input_length = new_input_length
            batch.max_input_length = max(batch.max_input_length, current_input_length)
            batch.input_lengths[i] = current_input_length
            current_length = current_cache_length + current_input_length
            batch.max_current_length = max(batch.max_current_length, current_length)

            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        if prefill and finished_prefilling:
            # We do not need prefill tensors anymore
            batch.cu_seqlen_prefill = None
            batch.prefill_cache_indices = None
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

    def _forward_context(
        self,
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        input_lengths_tensor: torch.Tensor,
        cache_lengths_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
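        # Select the FlashInfer wrapper state (prefill-with-paged-KV vs. decode) to
        # use for this forward pass; for non-flashinfer attention backends this is a
        # no-op.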
        if ATTENTION != "flashinfer":
            return nullcontext()

        from text_generation_server.layers.attention.flashinfer import (
            use_decode_state,
            use_prefill_with_paged_kv_state,
        )

        if cu_seqlen_prefill is not None:
            return use_prefill_with_paged_kv_state(
                state=(
                    state if state is not None else self.prefill_with_paged_kv_state
                ),
                # block_tables=block_tables_to_ragged(
                #     block_tables=block_tables,
                #     input_lengths=input_lengths,
                #     cache_lengths=cache_lengths,
                # ),
                block_tables=block_tables,
                cu_seqlens=cu_seqlen_prefill,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )
        else:
            assert input_lengths_tensor is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                block_tables=block_tables,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )


def block_tables_to_ragged(
    *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
) -> torch.Tensor:
    """Convert block table to ragged format compatible with FlashInfer."""
    assert len(input_lengths) == len(cache_lengths)
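    # E.g. (illustrative values): block_tables=[[0, 1, 2], [3, 4, 5]] with
    # input_lengths=[2, 3] and cache_lengths=[0, 0] flattens to [0, 1, 3, 4, 5]:
    # each request contributes its first (cache_length + input_length) block ids.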

    total_len = sum(input_lengths) + sum(cache_lengths)
    block_tables_ragged = torch.empty(
        total_len, dtype=torch.int32, device=block_tables.device
    )

    offset = 0
    for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
        seq_len = cache_length + input_length
        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
        offset += seq_len

    return block_tables_ragged