from contextlib import nullcontext
import math
import os
import time
import torch
import torch.distributed

import numpy as np

from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
from typing import (
    Any,
    ContextManager,
    Iterable,
    Optional,
    Tuple,
    List,
    Type,
    Dict,
    Union,
)

from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
    get_support_chunking,
    get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
    MEM_POOL,
    ATTENTION,
    BLOCK_SIZE,
    CUDA_GRAPHS,
    TGI_WIGGLE_ROOM,
    get_adapter_to_index,
)
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)
from text_generation_server.models.metadata_kernels import (
    has_triton,
    copy_next_input_ids_inplace,
    block_tables_to_ragged,
    block_tables_to_padded,
    prepare_position_slot_ids,
    slots_filtering,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    global SLIDING_WINDOW
    return SLIDING_WINDOW


def init_cpu_threads_env(rank_id: int, world_size: int):
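    """Pin this rank to a NUMA node and size its CPU thread pool.

    Only runs when the optional `numa` package is available: it memory-binds the
    process to its node, sets the torch thread count (honoring OMP_NUM_THREADS when
    set), and pins the process to a contiguous range of physical cores.
    """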
    import importlib.util

    if importlib.util.find_spec("numa") is not None:
        import numa
        import psutil

        nodes = numa.info.get_max_node() + 1
        rank_per_node = math.ceil(world_size / nodes)
        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
        node_id = int(rank_id / rank_per_node)
        rank_offset_per_node = rank_id % rank_per_node
        if os.getenv("OMP_NUM_THREADS") is None:
            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
        else:
            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
        if len(numa.memory.get_membind_nodes()) == nodes:
            numa.memory.set_membind_nodes((node_id))
        torch.set_num_threads(num_cpus_per_rank)
        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
            cpu_start = num_cpus_per_rank * rank_offset_per_node
            numa.schedule.run_on_cpus(
                0,
                *(
                    numa.info.node_to_cpus(node_id)[
                        cpu_start : cpu_start + num_cpus_per_rank
                    ]
                ),
            )
        logger.info(
            f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
        )


@dataclass
class FlashCausalLMBatch(Batch):
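    """Per-batch state for models served with flash/paged attention.

    Holds the tokenized inputs, the paged-attention block tables and slots,
    sampling and stopping state, and the prefill metadata that `generate_token`
    fills in and resets around each prefill forward.
    """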
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    # Can be a list for easy filtering
    # If `input_ids` is a list, it needs to be materialized to a tensor first
    input_ids: Union[torch.Tensor, List[List[int]]]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    position_ids: Optional[torch.Tensor]
    speculative_ids: Optional[torch.Tensor]

    # Set when creating the batch
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    slot_indices: Optional[torch.Tensor]

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor
    # list of length b + 1  containing the cumulative sequence slot lengths of the sequences in the batch
    # used for filtering
    cu_slots: torch.Tensor

    max_input_length: int
    max_current_length: int

    # Whether this batch contains at least one request that is prefilling
    prefilling: bool
    # Whether each request is prefilling
    prefilling_mask: List[bool]

    # Prefill metadata tensors to efficiently compute logprobs
    # tensor of length b + 1  containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_head_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_next_token_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_cu_outlens: Optional[List[int]]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_logprob_tokens: List[Optional[Tokens]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # size [b], containing the number of blocks that can be retrieved from the cache
    cache_lengths: List[int]
    prompt_lengths: List[int]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    input_lengths_tensor: Optional[torch.Tensor]
    cache_lengths_tensor: Optional[torch.Tensor]
    prompt_lengths_tensor: torch.Tensor

    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    adapter_meta: Optional[AdapterBatchMetadata]

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
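        """Convert this batch into its protobuf `CachedBatch` representation."""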
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
            current_tokens=(
                sum([len(i) for i in self.input_ids])
                if isinstance(self.input_ids, list)
                else len(self.input_ids)
            ),
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
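        """Tokenize each request's concatenated input chunks and return the token id lists."""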
        max_length = 0
        all_input_ids = []
        batch_size = 0
        for r in requests:
            batch_size += 1
            inputs = concat_text_chunks(r.input_chunks.chunks)
            input_ids = tokenizer(
                inputs,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            max_length = max(max_length, len(input_ids))
            all_input_ids.append(input_ids)
        return all_input_ids

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        speculate = get_speculate()

        cache_lengths = []
        input_lengths = []
        prompt_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        all_postfix_ids = []
        requests_idx_mapping = {}
        slots = []
        cu_slots = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        num_blocks = 0
        max_input_length = 0
        max_current_length = 0
        max_length = 0
        max_blocks = 0

        cu_blocks = [0]
        block_tables = []
        block_tables_ragged = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            prompt_length = len(tokenized_input)
            prompt_lengths.append(prompt_length)

            cache_length = r.cache_len

            assert (
                cache_length <= prompt_length
            ), f"Prefix {cache_length} vs input {prompt_length}"
            if cache_length == prompt_length:
                assert False, "unreachable"

            # `chunk_len` is an optional field in the protobuf
            # It is only set if the model supports chunking
            if r.HasField("chunk_len"):
                input_length = r.chunk_len

                if cache_length + input_length < prompt_length:
                    # FIXME: speculate is not supported for context chunking at the moment
                    assert speculate == 0
                    assert get_support_chunking()
                    assert input_length > 0

                postfix_ids = tokenized_input[
                    cache_length : cache_length + input_length
                ]
                assert (
                    len(postfix_ids) == input_length
                ), "Rust and Python tokenizers are not aligned"
            else:
                # Use all the remaining ids
                postfix_ids = tokenized_input[cache_length:]
                input_length = len(postfix_ids)

            input_lengths.append(input_length)

            prefix_offsets.append(prompt_length - 5)
            read_offsets.append(prompt_length)

            all_postfix_ids.append(postfix_ids)
            all_input_ids.append(tokenized_input)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length

            # Tokens that need to be mapped to blocks.
            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            block_tables_ragged.extend(request_blocks)
            cu_blocks.append(len(block_tables_ragged))

            slots.extend(request_slots)
            cu_slots.append(len(slots))

            cache_lengths.append(cache_length)
            num_blocks += len(request_blocks)

            # Update
            max_blocks = max(max_blocks, len(request_blocks))
            max_input_length = max(max_input_length, input_length)
            max_current_length = max(max_current_length, cache_length + input_length)
            max_length = max(
                max_length,
                prompt_length + max_new_tokens + speculative_length,
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        block_tables_ragged = torch.tensor(
            block_tables_ragged, device=device, dtype=torch.int32
        )
        cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
        block_tables_tensor = torch.empty(
            (len(block_tables), max_blocks),
            device=device,
            dtype=torch.int32,
        )

        # If the device supports Triton, we can use a fused kernel
        if has_triton():
            block_tables_to_padded(
                max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged
            )
        else:
            for i, request_blocks in enumerate(block_tables):
                block_tables_tensor[i, : len(request_blocks)] = torch.tensor(
                    request_blocks
                )

        prompt_lengths_tensor = torch.tensor(
            prompt_lengths, dtype=torch.int32, device=device
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=True,
            prefilling_mask=[True] * len(pb.requests),
            prefill_logprob_tokens=[None] * len(pb.requests),
            input_lengths=input_lengths,
            prompt_lengths=prompt_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=None,
            slots=slots,
            cu_slots=cu_slots,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            cache_lengths_tensor=None,
            input_lengths_tensor=None,
            adapter_meta=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        assert len(pb.requests) > 0
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
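        """Return a new batch restricted to the given request ids.

        Per-request lists are rebuilt and the batched tensors (slots, block tables,
        next-token chooser state, ...) are re-indexed; slots are filtered with a
        Triton kernel when the device supports it.
        """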
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.block_tables_tensor.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        if not has_triton():
            # slots to keep after filtering
            slot_filtering_indices = torch.zeros(
                self.slots.shape[0], dtype=torch.bool, device=device
            )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_input_length = 0
        max_current_length = 0

        requests = []
        block_tables = []
        all_input_ids = []
        input_ids = []

        prompt_lengths = []
        input_lengths = []
        cache_lengths = []
        prefix_offsets = []
        read_offsets = []
        cu_slots = [0]

        prefilling_mask = []
        prefill_logprob_tokens = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()

        num_blocks = 0
        max_blocks = 0
        max_slots = 0
        cumulative_slot_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Prefilling
            request_prefilling = self.prefilling_mask[idx]
            prefilling_mask.append(request_prefilling)

            # Get length
            request_input_length = self.input_lengths[idx]
            request_cache_length = self.cache_lengths[idx]
            max_input_length = max(max_input_length, request_input_length)
            max_current_length = max(
                max_current_length, request_cache_length + request_input_length
            )

            all_input_ids.append(self.all_input_ids[idx])

            prompt_lengths.append(self.prompt_lengths[idx])
            input_lengths.append(request_input_length)
            cache_lengths.append(request_cache_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])
            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)

            start_slot = self.cu_slots[idx]
            end_slot = self.cu_slots[idx + 1]
            slot_length = end_slot - start_slot

            if not has_triton():
                # Set slice
                slot_filtering_indices[start_slot:end_slot] = True

            cu_slots.append(cumulative_slot_tokens + slot_length)

            # Input ids if the request was part of a prefilling batch
            # If the batch was decoding we can index into the tensor directly later
            if self.prefilling:
                input_ids.append(self.input_ids[idx])
            else:
                # Copy to tensor (CPU)
                slot_indices[i] = cumulative_slot_tokens + request_cache_length

            cumulative_slot_tokens += slot_length
            max_blocks = max(max_blocks, len(request_block_table))
            max_slots = max(max_slots, slot_length)

        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )
        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

        if not has_triton():
            slots = self.slots[slot_filtering_indices]
        else:
            slots = self.slots.new_empty(cumulative_slot_tokens)
            gpu_cu_slots = cu_slots.to(device)
            slots_indexing_start = self.cu_slots.to(device)[indices]
            slots_filtering(
                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
            )

        if self.prefilling:
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
        else:
            # Index into tensors
            input_ids = self.input_ids[indices]
            position_ids = self.position_ids[indices]
            adapter_indices = self.adapter_meta.adapter_indices[indices]
            input_lengths_tensor = self.input_lengths_tensor[indices]
            cache_lengths_tensor = self.cache_lengths_tensor[indices]

            # Move to GPU now that we have the whole tensor
            slot_indices = slot_indices.to(device)

            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
            adapter_segments = torch.tensor(
                adapter_segments, dtype=torch.int32, device=device
            )
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            cu_slots=cu_slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=self.prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
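        """Merge several batches into a single one.

        Per-request lists are concatenated and the batched tensors are copied into
        freshly allocated ones; position, slot and adapter tensors are only rebuilt
        here when no batch is still prefilling, otherwise `prepare_for_prefill`
        will set them.
        """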
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        prefilling = False
        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_input_length = 0
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
            max_blocks = max(max_blocks, b.max_blocks)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_input_length = max(max_input_length, b.max_input_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
                    prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    for prompt_length, stopping_criteria in zip(
                        b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
            prefilling = prefilling or b.prefilling

        slots = batches[0].slots.new_empty(total_slots)
        cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
        if prefilling:
            input_ids = []
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
            adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            position_ids = batches[0].position_ids.new_empty(total_batch_size)
            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                total_batch_size
            )
            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
                total_batch_size
            )
            total_indices_size = sum(
                b.adapter_meta.adapter_indices.shape[0] for b in batches
            )
            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
                total_indices_size
            )
            adapter_segment_builder = SegmentConcatBuilder()
            adapter_set = set()

        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        block_tables = []
        cache_lengths = []
        all_input_ids = []

        prompt_lengths = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        prefill_logprob_tokens = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []
        prefilling_mask = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)
            slots[slots_start_index:slots_end_index] = batch.slots
            cu_slots[start_index + 1 : end_index + 1] = (
                batch.cu_slots[1:] + cumulative_slots
            )

            if not prefilling:
                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
                slot_indices[start_index:end_index] = (
                    batch.slot_indices + cumulative_slots
                )
                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

                # Copy over adapter indices
                adapter_start_index = cumulative_adapter_indices_size
                adapter_end_index = (
                    cumulative_adapter_indices_size
                    + batch.adapter_meta.adapter_indices.shape[0]
                )
                adapter_indices[adapter_start_index:adapter_end_index] = (
                    batch.adapter_meta.adapter_indices
                )
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
                    batch.adapter_meta.adapter_segments,
                    batch.adapter_meta.segment_indices,
                )
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
                input_ids.extend(batch.input_ids)

            prefilling_mask.extend(batch.prefilling_mask)
            block_tables.extend(batch.block_tables)
            cache_lengths.extend(batch.cache_lengths)
            all_input_ids.extend(batch.all_input_ids)

            prompt_lengths.extend(batch.prompt_lengths)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_slots += len(batch.slots)
            cumulative_batch_size += len(batch)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        if adapter_segment_builder is not None:
            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            slots=slots,
            cu_slots=cu_slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    def prepare_for_prefill(self):
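        """Materialize the tensors required by the prefill forward pass.

        Builds `position_ids`, `slot_indices`, `cu_seqlen_prefill`, the prefill
        logprob indices and the adapter metadata, using fused Triton kernels when
        the device supports them.
        """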
        # Prepare values if we need to continue prefilling
        # Speculation must be ignored while we prefill, even with chunking:
        # it simplifies everything
        assert self.speculative_ids is None

        device = self.block_tables_tensor.device

        if isinstance(self.input_ids, list):
            if len(self) > 1:
                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
            else:
                input_ids = self.input_ids[0]
            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        self.input_lengths_tensor = torch.tensor(
            self.input_lengths, dtype=torch.int32, device=device
        )
        self.cu_seqlen_prefill = torch.nn.functional.pad(
            torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
        ).to(torch.int32)
        self.cache_lengths_tensor = torch.tensor(
            self.cache_lengths, dtype=torch.int32, device=device
        )

        # If the device supports Triton, we can use a fused kernel
        if has_triton():
            self.position_ids = torch.empty(
                len(self.input_ids), dtype=torch.int32, device=device
            )
            self.slot_indices = torch.empty(
                len(self.input_ids), dtype=torch.int64, device=device
            )
            cu_slots_gpu = self.cu_slots.to(device)

            prepare_position_slot_ids(
                self.max_input_length,
                self.cache_lengths_tensor,
                self.cu_seqlen_prefill,
                cu_slots_gpu,
                self.position_ids,
                self.slot_indices,
            )

        sliding_window = get_sliding_windows()
        position_ids = []
        slot_indices = []
        prefill_cache_indices = []
        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_cu_outlens = [0]

        # Cumulative length
        cumulative_length = 0
        cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0

        adapter_indices_list = []
        adapter_set = set()

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
            blocks,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
                self.block_tables,
            )
        ):
            next_chunk_length = input_length

            if not has_triton():
                # Position ids
                request_position_ids = torch.arange(
                    cache_length, cache_length + input_length, dtype=torch.int32
                )
                position_ids.append(request_position_ids)

                if not r.slots:
                    request_slots = [
                        s
                        for b in blocks
                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                    ]
                else:
                    request_slots = r.slots

                request_slot_indices = torch.arange(
                    cache_length + cumulative_slot_tokens,
                    cache_length + cumulative_slot_tokens + input_length,
                    dtype=torch.int64,
                )

                slot_indices.append(request_slot_indices)

                # Update
                cumulative_slot_tokens += len(request_slots)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )

            # Prefill logprobs is ignored if the request is done prefilling
            prefill_logprobs = r.prefill_logprobs and request_prefilling

            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

            if prefill_logprobs:
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            if sliding_window is not None:
                prefill_cache_indices.append(request_prefill_cache_indices)

            ADAPTER_TO_INDEX = get_adapter_to_index()
            if ADAPTER_TO_INDEX:
                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
                adapter_indices_list.append(
                    torch.full((next_chunk_length,), adapter_index)
                )
                adapter_set.add(adapter_index)

            # Update
            cumulative_length += next_chunk_length

        if not all_prefill_logprobs and not no_prefill_logprobs:
            prefill_head_indices = []
            prefill_next_token_indices = []

            # Cumulative length
            cumulative_length = 0
            prefill_out_cumulative_length = 0

            for i, (
                r,
                input_length,
                request_prefilling,
            ) in enumerate(
                zip(
                    self.requests,
                    self.input_lengths,
                    self.prefilling_mask,
                )
            ):
                # Prefill logprobs is ignored if the request is done prefilling
                prefill_logprobs = r.prefill_logprobs and request_prefilling

                if prefill_logprobs:
                    prefill_head_indices.append(
                        torch.arange(
                            cumulative_length,
                            cumulative_length + input_length,
                            dtype=torch.int64,
                        )
                    )
                    prefill_next_token_indices.append(
                        prefill_out_cumulative_length + input_length - 1
                    )
                    prefill_out_cumulative_length += input_length
                else:
                    prefill_head_indices.append(
                        torch.tensor(
                            [cumulative_length + input_length - 1],
                            dtype=torch.int64,
                        )
                    )
                    prefill_next_token_indices.append(prefill_out_cumulative_length)
                    prefill_out_cumulative_length += 1

                # Update
                cumulative_length += input_length

        if len(self) > 1:
            if position_ids:
                position_ids = torch.cat(position_ids)
            if slot_indices:
                slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            if position_ids:
                position_ids = position_ids[0]
            if slot_indices:
                slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        if not has_triton():
            self.position_ids = position_ids.to(device)
            self.slot_indices = slot_indices.to(device)

        self.prefill_cu_outlens = prefill_cu_outlens
        self.prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        self.prefill_head_indices = prefill_head_indices
        self.prefill_next_token_indices = prefill_next_token_indices

        if adapter_set:
            adapter_indices = torch.cat(adapter_indices_list).to(
                dtype=torch.int64, device=device
            )
            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        else:
            adapter_indices = torch.zeros_like(self.input_ids)
            adapter_segments = [0, len(adapter_indices)]
            adapter_segment_indices = [len(adapter_indices) - 1]

        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        self.adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

    def __len__(self):
        return len(self.requests)


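# Projection layer names that LoRA adapters can target, and the subset of those layers that is
# row-parallel when the model is sharded with tensor parallelism.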
ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = [],
        tokenizer_class=AutoTokenizer,
        config_class=AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
        kv_cache_dtype: Optional[torch.dtype] = None,
        support_chunking: bool = True,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
                init_cpu_threads_env(rank_id=rank, world_size=world_size)
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            pass

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads // self.process_group.size()
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0
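        # Illustrative example: with 32 attention heads, 8 KV heads and a world size of 4, each
        # rank keeps 8 query heads and 2 KV heads; MQA models (num_kv_heads == 1) keep their
        # single KV head replicated on every rank instead of sharding it.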

        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj; using head_dim allows for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
            support_chunking=support_chunking,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def max_past(self) -> Optional[int]:
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.kv_cache = []
        empty_cache()
        self.kv_cache = [
            KVCache(
                num_blocks=num_blocks,
                num_heads=num_heads,
                head_size=head_size,
                dtype=dtype,
                device=device,
            )
            for _ in range(num_layers)
        ]

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
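        # Dummy decode-shaped inputs for graph capture: one token per sequence, each padded to
        # max_s tokens of context and max_bt blocks, so the captured graph covers the worst
        # case for this batch size.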
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = [max_s] * bs
        cache_lengths = [0] * bs
        input_lengths_tensor = (
            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        )
        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
        block_tables = torch.arange(
            max_bt, dtype=torch.int32, device=self.device
        ).repeat(bs)
        block_tables = block_tables.reshape((bs, max_bt))

        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=input_lengths,
                cache_lengths=cache_lengths,
                input_lengths_tensor=input_lengths_tensor,
                cache_lengths_tensor=cache_lengths_tensor,
                max_current_length=max_s,
            )
            from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )

            block_tables_ptr = torch.zeros(
                bs + 1, dtype=torch.int32, device=self.device
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
        else:
            state = None

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
            "cache_lengths": cache_lengths_tensor,
            "state": state,
            "graph": graph,
        }

        torch.cuda.synchronize()
        # Run once outside to warmup
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths_tensor=input_lengths_tensor,
            state=state,
            cache_lengths_tensor=cache_lengths_tensor,
        ):
            seqlen = Seqlen(
                input_lengths=input_lengths_tensor,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            del seqlen

            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                seqlen = Seqlen(
                    input_lengths=input_lengths_tensor,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=None,
                    max_q=1,
                    max_k=max_s,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=None,
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
                )
                self.cuda_graphs[bs]["logits"] = logits
                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.kv_cache_dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                # torch.cuda.tunable.tuning_enable(False)
                pass
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
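        # Illustrative example: with BLOCK_SIZE=16, 8 KV heads per rank, head_size=128,
        # 32 layers and a 2-byte cache dtype, one block holds 16 * 8 * 128 = 16384 elements per
        # layer for K and for V, i.e. 32 * 16384 * 2 * 2 bytes = 2 MiB of KV-cache per block.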

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
        batch_num_blocks = batch.num_blocks if batch is not None else 0

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )
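        # Continuing the example above: assuming TGI_WIGGLE_ROOM is 0.95 and ~20 GiB are free,
        # this yields roughly 20480 * 0.95 / 2 ~= 9700 blocks, plus the blocks that were already
        # allocated for the warmup batch.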

        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")

        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.kv_cache_dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            if (
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                # torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    # torch.cuda.tunable.tuning_enable(True)
                    pass

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                log_master(
                    logger.info,
                    f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                )

                # torch.cuda.tunable.set_filename(
                #     tunableop_filepath, insert_device_ordinal=False
                # )

                # if os.path.isfile(tunableop_filepath):
                #     log_master(
                #         logger.info,
                #         f"The file {tunableop_filepath} already exists and will be reused.",
                #     )
                #     torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                # for seqlen in tuning_sequences:
                #     log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                #     self.tunableop_warmup(seqlen)
                #     torch.cuda.tunable.write_file(tunableop_filepath)
                # if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                #     torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                )

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )

        return int(num_blocks * BLOCK_SIZE)

    def tunableop_warmup(self, seqlen: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        cache_lengths_tensor = torch.zeros(
            seqlen, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.tensor(
            [0, seqlen], device=self.device, dtype=torch.int32
        )
        max_s = seqlen
        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
            max_q=1,
            max_k=seqlen,
        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=None,
            seqlen=seqlen,
            slots=slots,
            max_s=max_s,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length
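            # Illustrative shapes: with B=3 requests and speculative_length=2, each request
            # contributes new_length=3 candidate tokens, so input_ids, position_ids, slots and
            # input_lengths flatten to 1D tensors of size B * new_length = 9 and the block
            # tables are repeated once per candidate token.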

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
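        # Illustrative example: with captured graph sizes {1, 2, 4, 8} and a live batch of 3
        # sequences, the size-4 graph is replayed and the extra padded slot is discarded when
        # the logits are sliced back to the real batch size below.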

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    adapter_data=adapter_data,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.prefilling
        if prefill:
            batch.prepare_for_prefill()

        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When len(batch) == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
        else:
            prefill_logprobs = None
            next_token_logits = out

        finished_prefilling = True
        next_chunk_lengths = []
        current_prefilling_mask = batch.prefilling_mask
        if prefill:
            if get_support_chunking():
                next_prefilling_mask = []
                # Budget in tokens for the next batch
                # We remove (len(batch) - 1) so there is always room for at least a single decode
                # token for each of the remaining requests; the first request is not subtracted
                # because it can use the full budget
                # (ex: with one request in the batch, it should get the full budget, not budget - 1)
                batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
                # We reverse to prioritize older requests
                # zip() is not reversible so reverse the underlying lists instead
                for cache_length, input_length, prompt_length in zip(
                    reversed(batch.cache_lengths),
                    reversed(batch.input_lengths),
                    reversed(batch.prompt_lengths),
                ):
                    remaining_prefill_tokens = max(
                        prompt_length - cache_length - input_length, 0
                    )
                    if remaining_prefill_tokens > 0:
                        next_chunk_length = max(
                            min(remaining_prefill_tokens, batch_budget), 1
                        )
                        batch_budget -= next_chunk_length
                        finished_prefilling = False
                        next_prefilling_mask.append(True)
                    else:
                        # FIXME: use true number of accepted tokens instead of 1
                        # Since speculation will be turned off, this is always true
                        next_chunk_length = 1
                        next_prefilling_mask.append(False)
                    next_chunk_lengths.append(next_chunk_length)

                # Reverse back the obtained values
                next_chunk_lengths.reverse()
                next_prefilling_mask.reverse()
            else:
                # The model does not support chunking
                # We know we only do a single prefill
                finished_prefilling = True
                next_prefilling_mask = [False] * len(batch)

            batch.prefilling = not finished_prefilling
            batch.prefilling_mask = next_prefilling_mask
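            # Illustrative example: with a 512-token prefill budget and three requests that
            # still need 600, 100 and 0 prompt tokens, the finished request is scheduled for a
            # single decode token while the other two share the budget and stay in the
            # prefilling state for the next iteration.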

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_current_length],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        # Once prefilling is done, the tensors that concatenated values for all the requests
        # collapse to shape [BATCH_SIZE]
        if prefill and finished_prefilling:
            indices = batch.cu_seqlen_prefill[1:] - 1
            batch.position_ids = batch.position_ids[indices]
            batch.slot_indices = batch.slot_indices[indices]
            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                indices
            ]
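            # Illustrative example: with cu_seqlen_prefill = [0, 3, 7], indices = [2, 6], i.e.
            # the last prefill position of each request.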

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids,
            current_prefilling_mask,
            batch.prefilling_mask,
        )

        # We run two for loops: the first one can execute completely asynchronously from the GPU,
        # while the second one first needs a GPU <-> CPU sync.
        # It is faster to delay this sync for as long as possible.

        # For each member of the batch
        # Cumulative length
        cu_accepted_ids = torch.nn.functional.pad(
            torch.cumsum(accepted_ids, dim=0), (1, 0)
        )
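        # Illustrative example: with accepted_ids = [1, 3, 2], cu_accepted_ids = [0, 1, 4, 6], so
        # request i owns next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].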
        cumulative_length = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            all_input_ids,
            n_accepted_ids,
            request_was_prefilling,
            request_is_prefilling,
        ) in enumerate(iterator):
            # Used to gather prefill logprobs
            # Copy batch.all_input_ids_tensor to prefill_tokens_indices
            if request.prefill_logprobs and request_was_prefilling:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]

                # Logprobs generated by the model are for the next token
                # So we need to translate the id tensor by 1
                ids = batch.all_input_ids_tensor[
                    i, cache_length + 1 : cache_length + input_length + 1
                ]
                if len(batch) > 1:
                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = ids

            # If the device does not support triton, we copy one by one
            if not request_is_prefilling and not has_triton():
                # Only save tokens if we are done prefilling for this request
                batch.all_input_ids_tensor[
                    i,
                    batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
            cumulative_length += input_length

        # If the device supports triton, we can use a fused kernel
        if has_triton():
            copy_next_input_ids_inplace(
                speculate + 1,
                batch.all_input_ids_tensor,
                batch.cache_lengths_tensor,
                batch.input_lengths_tensor,
                batch.prompt_lengths_tensor,
                next_input_ids,
                cu_accepted_ids,
            )

        # Update values
        # These values can be updated without a GPU -> CPU sync
        if not prefill or (prefill and finished_prefilling):
            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
            batch.speculative_ids = speculative_ids
            batch.position_ids += accepted_ids
            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
            batch.slot_indices += accepted_ids
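            # After this update every request is in decode mode: its cache now covers everything
            # except the single new input token that remains to be processed.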

        if prefill and prefill_logprobs:
            # Get prefill logprobs with an in-place softmax to avoid copying the `out` tensor
            # (up to max_batch_prefill_tokens * vocab_size elements)
            torch.log_softmax(out, -1, out=out)
            prefill_logprobs_tensor = out
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # Does a GPU <-> CPU sync internally
        if prefill and finished_prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()

        # Update values if we need to continue prefilling
        # This represents the `else` case of the `Update values` if above,
        # but since this requires `next_token_ids` to be on CPU, it is better to do it here
        if prefill and not finished_prefilling:
            # Speculation must be ignored while we prefill even with chunking
            # it simplifies everything
            assert batch.speculative_ids is None

            all_postfix_ids = []
            for i, (
                request_prefilling,
                next_token_id,
                all_input_ids,
                cache_length,
                input_length,
                next_chunk_length,
            ) in enumerate(
                zip(
                    batch.prefilling_mask,
                    next_token_ids,
                    batch.all_input_ids,
                    batch.cache_lengths,
                    batch.input_lengths,
                    next_chunk_lengths,
                )
            ):
                if request_prefilling:
                    next_cache_length = cache_length + input_length
                    # Get new prompt IDs to prefill
                    postfix_ids = all_input_ids[
                        next_cache_length : next_cache_length + next_chunk_length
                    ]
                else:
                    # This request is done prefilling, the new id is the one selected by the sampling method
                    postfix_ids = [next_token_id]

                all_postfix_ids.append(postfix_ids)

            batch.input_ids = all_postfix_ids
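            # batch.input_ids is now a list of per-request id lists: the next prompt chunk for
            # requests that keep prefilling, or the single sampled token for the others.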

        start_decode = time.time_ns()

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            current_prefilling_mask,
            batch.prefilling_mask,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # Reset max_input_length
        batch.max_input_length = 0
        # For each member of the batch
        index = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            request_was_prefilling,
            request_is_prefilling,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Compute logprobs first: even though we might skip the token, the logprobs can
            # still be required.
            # We use request.id modulo world_size as it is robust to batch.filter, whereas the
            # index in the batch is not, and we need this state to be stable.
            if request.id % self.world_size == self.rank:
                # Prefill
                if request_was_prefilling and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]
                    if not request_is_prefilling:
                        # The request is done prefilling, meaning that we started generating new tokens
                        # The last logprob is a logprob for a generated token that was not part of the prompt
                        # We need to remove it
                        out_end_index -= 1

                    request_prefill_logprobs = prefill_logprobs[
                        out_start_index:out_end_index
                    ]
                    # Logprobs generated by the model are for the next token
                    # So we need to translate the id tensor by 1
                    prefill_token_ids = all_input_ids[
                        cache_length + 1 : cache_length + input_length + 1
                    ]

                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

                    if past_prefill_logprob_tokens is None:
                        # add nan for cached prompt tokens/first token
                        request_prefill_logprobs = [float("nan")] * (
                            cache_length + 1
                        ) + request_prefill_logprobs
                        prefill_token_ids = (
                            all_input_ids[: cache_length + 1] + prefill_token_ids
                        )

                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_logprob_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                    if past_prefill_logprob_tokens is not None:
                        prefill_logprob_tokens = (
                            past_prefill_logprob_tokens + prefill_logprob_tokens
                        )

                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                else:
                    batch.prefill_logprob_tokens[i] = None

            # If the request is still prefilling, the tokens we decoded should be ignored
            if request_is_prefilling:
                # Make sure that we do not stop: even though this request did not create a token,
                # it is still processing
                stopped = False
                new_input_length = next_chunk_lengths[i]
                new_cache_length = cache_length + input_length
            else:
                new_input_length = 1
                new_cache_length = cache_length + input_length + n_accepted_ids - 1
                # Append next token to all tokens
                next_token_texts = []
                left = 0

                if n_accepted_ids > 1:
                    log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

                current_stopped = False
                for j in range(index, index + n_accepted_ids):
                    # Generated token
                    next_token_id = next_token_ids[j]
                    all_input_ids.append(next_token_id)
                    next_token_text, prefix_offset, read_offset = self.decode_token(
                        all_input_ids,
                        prefix_offset,
                        read_offset,
                    )
                    next_token_texts.append(next_token_text)

                    stop, reason = stopping_criteria(
                        next_token_id,
                        next_token_text,
                    )

                    if stop:
                        left = index + n_accepted_ids - j - 1
                        current_stopped = True
                        break
                    else:
                        current_stopped = False
                stopped = stopped and current_stopped

                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
                _next_token_logprobs = next_token_logprobs[
                    index : index + n_accepted_ids - left
                ]

                # Shard generations
                # All generations will be appended in the rust sharded client
                if request.id % self.world_size == self.rank:
                    if stop:
                        # Decode generated tokens
                        output_text, _, _ = self.decode_token(
                            all_input_ids,
                            prefix_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens
                            - 1,
                            read_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens,
                            skip_special_tokens=True,
                        )
                        generated_text = GeneratedText(
                            output_text,
                            stopping_criteria.current_tokens,
                            reason,
                            seed if do_sample else None,
                        )
                    else:
                        generated_text = None

                    if top_n_tokens > 0:
                        all_top_tokens = []
                        for top_token_ids, top_token_logprobs in zip(
                            top_token_ids, top_token_logprobs
                        ):
                            toptoken_texts = self.tokenizer.batch_decode(
                                top_token_ids,
                                clean_up_tokenization_spaces=False,
                                skip_special_tokens=False,
                            )
                            special_toptokens = [
                                token_id in self.all_special_ids
                                for token_id in top_token_ids
                            ]
                            top_tokens = Tokens(
                                top_token_ids,
                                top_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                            all_top_tokens.append(top_tokens)
                        top_tokens = all_top_tokens
                    else:
                        top_tokens = None

                    generation = Generation(
                        request.id,
                        batch.prefill_logprob_tokens[i],
                        Tokens(
                            _next_token_ids,
                            _next_token_logprobs,
                            next_token_texts,
                            [nid in self.all_special_ids for nid in _next_token_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )

                    generations.append(generation)

                # accept each new token for this specific request since we may
                # have more than one new token per request with speculative decoding
                for next_token_id in _next_token_ids:
                    batch.next_token_chooser = (
                        batch.next_token_chooser.advance_grammar_single(
                            i, next_token_id
                        )
                    )

            # Update values
            index += n_accepted_ids
            batch.cache_lengths[i] = new_cache_length
            batch.max_input_length = max(batch.max_input_length, new_input_length)
            batch.input_lengths[i] = new_input_length
            current_length = new_cache_length + new_input_length
            batch.max_current_length = max(batch.max_current_length, current_length)

            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        if prefill and finished_prefilling:
            # We do not need prefill tensors anymore
            batch.cu_seqlen_prefill = None
            batch.prefill_cache_indices = None
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

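    # Select the attention metadata context for this forward pass: a no-op unless flashinfer is
    # used, in which case the appropriate prefill or decode state wrapper is entered.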
    def _forward_context(
        self,
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        input_lengths_tensor: torch.Tensor,
        cache_lengths_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
        if ATTENTION != "flashinfer":
            return nullcontext()

        from text_generation_server.layers.attention.flashinfer import (
            use_decode_state,
            use_prefill_with_paged_kv_state,
        )

        if cu_seqlen_prefill is not None:
            return use_prefill_with_paged_kv_state(
                state=(
                    state if state is not None else self.prefill_with_paged_kv_state
                ),
                block_tables=block_tables,
                cu_seqlens=cu_seqlen_prefill,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )
        else:
            assert input_lengths_tensor is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                block_tables=block_tables,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                kv_cache_dtype=self.kv_cache_dtype,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )