from contextlib import nullcontext
import math
import os
import time
import torch
import torch.distributed

import numpy as np

from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
from typing import (
    Any,
    ContextManager,
    Iterable,
    Optional,
    Tuple,
    List,
    Type,
    Dict,
    Union,
)

from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
    get_support_chunking,
    get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
    MEM_POOL,
    ATTENTION,
    BLOCK_SIZE,
    CUDA_GRAPHS,
    REQUEST_LOGPROBS,
    TGI_WIGGLE_ROOM,
    get_adapter_to_index,
)
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)
from text_generation_server.models.metadata_kernels import (
    has_triton,
    copy_next_input_ids_inplace,
    block_tables_to_ragged,
    block_tables_to_padded,
    prepare_position_slot_ids,
    slots_filtering,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None


def small_power_of_2(n: int):
    return 1 << ((n - 1).bit_length() - 1)
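# Illustrative examples (computed from the expression above): small_power_of_2(3) == 2,
# small_power_of_2(8) == 4, small_power_of_2(9) == 8, i.e. the largest power of two
# strictly below `n` for n >= 2.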


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> int:
    global SLIDING_WINDOW
    return SLIDING_WINDOW
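# Usage sketch (illustrative): `FlashCausalLM.__init__` calls
# `set_sliding_window(config.sliding_window)` when the model config defines one, and
# `prepare_for_prefill` reads it back through `get_sliding_windows()` to build the
# prefill cache indices.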


def init_cpu_threads_env(rank_id: int, world_size: int):
    import importlib.util

    if importlib.util.find_spec("numa") is not None:
        import numa
        import psutil

        nodes = numa.info.get_max_node() + 1
        rank_per_node = math.ceil(world_size / nodes)
        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
        node_id = int(rank_id / rank_per_node)
        rank_offset_per_node = rank_id % rank_per_node
        if os.getenv("OMP_NUM_THREADS") is None:
            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
        else:
            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
        if len(numa.memory.get_membind_nodes()) == nodes:
            numa.memory.set_membind_nodes((node_id))
        torch.set_num_threads(num_cpus_per_rank)
        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
            cpu_start = num_cpus_per_rank * rank_offset_per_node
            numa.schedule.run_on_cpus(
                0,
                *(
                    numa.info.node_to_cpus(node_id)[
                        cpu_start : cpu_start + num_cpus_per_rank
                    ]
                ),
            )
        logger.info(
            f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
        )


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    # Can be a list for easy filtering
    # If `input_ids` is a list, it needs to be materialized to a tensor first
    input_ids: Union[torch.Tensor, List[List[int]]]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    position_ids: Optional[torch.Tensor]
    speculative_ids: Optional[torch.Tensor]

    # Set when creating the batch
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    slot_indices: Optional[torch.Tensor]

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor
    # list of length b + 1  containing the cumulative sequence slot lengths of the sequences in the batch
    # used for filtering
    cu_slots: torch.Tensor

    max_input_length: int
    max_current_length: int

    # Whether this batch contains at least one request that is prefilling
    prefilling: bool
    # Whether each request is prefilling
    prefilling_mask: List[bool]

    # Prefill metadata tensors to efficiently compute logprobs
    # tensor of length b + 1  containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_head_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_next_token_indices: Optional[torch.Tensor]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_cu_outlens: Optional[List[int]]
    # Will be set by `generate_token` and reset after each prefill forward
    prefill_logprob_tokens: List[Optional[Tokens]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    # size [b], containing the number of blocks that can be retrieved from the cache
    cache_lengths: List[int]
    prompt_lengths: List[int]
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    input_lengths_tensor: Optional[torch.Tensor]
    cache_lengths_tensor: Optional[torch.Tensor]
    prompt_lengths_tensor: torch.Tensor

    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
    adapter_meta: Optional[AdapterBatchMetadata]

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
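        """Serialize this batch into a `generate_pb2.CachedBatch` summary (request ids,
        size, max and current token counts); the heavy tensors themselves are not included."""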
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
            current_tokens=(
                sum([len(i) for i in self.input_ids])
                if isinstance(self.input_ids, list)
                else len(self.input_ids)
            ),
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
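        """Tokenize every request's concatenated text chunks and return the list of
        input id lists, in request order (truncation and special tokens follow the
        per-request protobuf parameters)."""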
        max_length = 0
        all_input_ids = []
        batch_size = 0
        for r in requests:
            batch_size += 1
            inputs = concat_text_chunks(r.input_chunks.chunks)
            input_ids = tokenizer(
                inputs,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            max_length = max(max_length, len(input_ids))
            all_input_ids.append(input_ids)
        return all_input_ids

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        speculate = get_speculate()

        cache_lengths = []
        input_lengths = []
        prompt_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        all_postfix_ids = []
        requests_idx_mapping = {}
        slots = []
        cu_slots = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        num_blocks = 0
        max_input_length = 0
        max_current_length = 0
        max_length = 0
        max_blocks = 0

        cu_blocks = [0]
        block_tables = []
        block_tables_ragged = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            ### XXX: This consumes so much memory on long requests
            ### Deactivating it by default seems like the best course.
            if not REQUEST_LOGPROBS:
                r.prefill_logprobs = False
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            prompt_length = len(tokenized_input)
            prompt_lengths.append(prompt_length)

            cache_length = r.cache_len

            assert (
                cache_length <= prompt_length
            ), f"Prefix {cache_length} vs input {prompt_length}"
            if cache_length == prompt_length:
                assert False, "unreachable"

            # `chunk_len` is an optional field in the protobuf
            # It is only set if the model support chunking
            if r.HasField("chunk_len"):
                input_length = r.chunk_len

                if cache_length + input_length < prompt_length:
                    # FIXME: speculate is not supported for context chunking at the moment
                    assert speculate == 0
                    assert get_support_chunking()
                    assert input_length > 0

                postfix_ids = tokenized_input[
                    cache_length : cache_length + input_length
                ]
                assert (
                    len(postfix_ids) == input_length
                ), "Rust and Python tokenizers are not aligned"
            else:
                # Use all the remaining ids
                postfix_ids = tokenized_input[cache_length:]
                input_length = len(postfix_ids)

            input_lengths.append(input_length)

            prefix_offsets.append(prompt_length - 5)
            read_offsets.append(prompt_length)

            all_postfix_ids.append(postfix_ids)
            all_input_ids.append(tokenized_input)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length

            # Tokens that need to be mapped to blocks.
            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            block_tables_ragged.extend(request_blocks)
            cu_blocks.append(len(block_tables_ragged))

            slots.extend(request_slots)
            cu_slots.append(len(slots))

            cache_lengths.append(cache_length)
            num_blocks += len(request_blocks)

            # Update
            max_blocks = max(max_blocks, len(request_blocks))
            max_input_length = max(max_input_length, input_length)
            max_current_length = max(max_current_length, cache_length + input_length)
            max_length = max(
                max_length,
                prompt_length + max_new_tokens + speculative_length,
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        block_tables_ragged = torch.tensor(
            block_tables_ragged, device=device, dtype=torch.int32
        )
        cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
        block_tables_tensor = torch.empty(
            (len(block_tables), max_blocks),
            device=device,
            dtype=torch.int32,
        )

        # If the device supports Triton, we can use a fused kernel
        if has_triton():
            block_tables_to_padded(
                max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged
            )
        else:
            for i, request_blocks in enumerate(block_tables):
                block_tables_tensor[i, : len(request_blocks)] = torch.tensor(
                    request_blocks
                )

        prompt_lengths_tensor = torch.tensor(
            prompt_lengths, dtype=torch.int32, device=device
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=True,
            prefilling_mask=[True] * len(pb.requests),
            prefill_logprob_tokens=[None] * len(pb.requests),
            input_lengths=input_lengths,
            prompt_lengths=prompt_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=None,
            slots=slots,
            cu_slots=cu_slots,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            cache_lengths_tensor=None,
            input_lengths_tensor=None,
            adapter_meta=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        assert len(pb.requests) > 0
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
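        """Return a batch restricted to `request_ids`: CPU-side lists are rebuilt and
        GPU tensors are re-indexed (or the batch is returned unchanged when every
        request is kept)."""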
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.block_tables_tensor.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        if not has_triton():
            # slots to keep after filtering
            slot_filtering_indices = torch.zeros(
                self.slots.shape[0], dtype=torch.bool, device=device
            )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_input_length = 0
        max_current_length = 0

        requests = []
        block_tables = []
        all_input_ids = []
        input_ids = []

        prompt_lengths = []
        input_lengths = []
        cache_lengths = []
        prefix_offsets = []
        read_offsets = []
        cu_slots = [0]

        prefilling_mask = []
        prefill_logprob_tokens = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()

        num_blocks = 0
        max_blocks = 0
        max_slots = 0
        cumulative_slot_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Prefilling
            request_prefilling = self.prefilling_mask[idx]
            prefilling_mask.append(request_prefilling)

            # Get length
            request_input_length = self.input_lengths[idx]
            request_cache_length = self.cache_lengths[idx]
            max_input_length = max(max_input_length, request_input_length)
            max_current_length = max(
                max_current_length, request_cache_length + request_input_length
            )

            all_input_ids.append(self.all_input_ids[idx])

            prompt_lengths.append(self.prompt_lengths[idx])
            input_lengths.append(request_input_length)
            cache_lengths.append(request_cache_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])
            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)

            start_slot = self.cu_slots[idx]
            end_slot = self.cu_slots[idx + 1]
            slot_length = end_slot - start_slot

            if not has_triton():
                # Set slice
                slot_filtering_indices[start_slot:end_slot] = True

            cu_slots.append(cumulative_slot_tokens + slot_length)

            # Input ids if the request was part of a prefilling batch
            # If the batch was decoding we can index into the tensor directly later
            if self.prefilling:
                input_ids.append(self.input_ids[idx])
            else:
                # Copy to tensor (CPU)
                slot_indices[i] = cumulative_slot_tokens + request_cache_length

            cumulative_slot_tokens += slot_length
            max_blocks = max(max_blocks, len(request_block_table))
            max_slots = max(max_slots, slot_length)

        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )
        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

        if not has_triton():
            slots = self.slots[slot_filtering_indices]
        else:
            slots = self.slots.new_empty(cumulative_slot_tokens)
            gpu_cu_slots = cu_slots.to(device)
            slots_indexing_start = self.cu_slots.to(device)[indices]
            slots_filtering(
                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
            )

        if self.prefilling:
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
        else:
            # Index into tensors
            input_ids = self.input_ids[indices]
            position_ids = self.position_ids[indices]
            adapter_indices = self.adapter_meta.adapter_indices[indices]
            input_lengths_tensor = self.input_lengths_tensor[indices]
            cache_lengths_tensor = self.cache_lengths_tensor[indices]

            # Move to GPU now that we have the whole tensor
            slot_indices = slot_indices.to(device)

            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
            adapter_segments = torch.tensor(
                adapter_segments, dtype=torch.int32, device=device
            )
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            cu_slots=cu_slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=self.prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
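        """Merge several batches into one: CPU lists are concatenated and GPU tensors
        are copied into pre-allocated buffers. If any input batch is still prefilling,
        the decode-only tensors stay unset until `prepare_for_prefill`."""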
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        prefilling = False
        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_input_length = 0
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
            max_blocks = max(max_blocks, b.max_blocks)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_input_length = max(max_input_length, b.max_input_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
                    prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    for prompt_length, stopping_criteria in zip(
                        b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
            prefilling = prefilling or b.prefilling

        slots = batches[0].slots.new_empty(total_slots)
        cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
        if prefilling:
            input_ids = []
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids = None
            slot_indices = None
            cache_lengths_tensor = None
            input_lengths_tensor = None
            adapter_meta = None
            adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            position_ids = batches[0].position_ids.new_empty(total_batch_size)
            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                total_batch_size
            )
            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
                total_batch_size
            )
            total_indices_size = sum(
                b.adapter_meta.adapter_indices.shape[0] for b in batches
            )
            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
                total_indices_size
            )
            adapter_segment_builder = SegmentConcatBuilder()
            adapter_set = set()

        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        block_tables = []
        cache_lengths = []
        all_input_ids = []

        prompt_lengths = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        prefill_logprob_tokens = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []
        prefilling_mask = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)
            slots[slots_start_index:slots_end_index] = batch.slots
            cu_slots[start_index + 1 : end_index + 1] = (
                batch.cu_slots[1:] + cumulative_slots
            )

            if not prefilling:
                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
                slot_indices[start_index:end_index] = (
                    batch.slot_indices + cumulative_slots
                )
                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

                # Copy over adapter indices
                adapter_start_index = cumulative_adapter_indices_size
                adapter_end_index = (
                    cumulative_adapter_indices_size
                    + batch.adapter_meta.adapter_indices.shape[0]
                )
                adapter_indices[adapter_start_index:adapter_end_index] = (
                    batch.adapter_meta.adapter_indices
                )
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
                    batch.adapter_meta.adapter_segments,
                    batch.adapter_meta.segment_indices,
                )
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
                input_ids.extend(batch.input_ids)

            prefilling_mask.extend(batch.prefilling_mask)
            block_tables.extend(batch.block_tables)
            cache_lengths.extend(batch.cache_lengths)
            all_input_ids.extend(batch.all_input_ids)

            prompt_lengths.extend(batch.prompt_lengths)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_slots += len(batch.slots)
            cumulative_batch_size += len(batch)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        # We skip computing the speculative_ids when the batch size is too large, so
        # we must check that all batches have them, otherwise they must be discarded
        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
        else:
            speculative_ids = None

        if adapter_segment_builder is not None:
            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            cache_lengths=cache_lengths,
            cache_lengths_tensor=cache_lengths_tensor,
            slots=slots,
            cu_slots=cu_slots,
            max_input_length=max_input_length,
            max_current_length=max_current_length,
            prefilling=prefilling,
            prefilling_mask=prefilling_mask,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            prefill_logprob_tokens=prefill_logprob_tokens,
            prompt_lengths=prompt_lengths,
            prompt_lengths_tensor=prompt_lengths_tensor,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    def prepare_for_prefill(self):
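        """Materialize the tensors needed for a prefill forward (input ids, position
        ids, slot indices, `cu_seqlen_prefill`, prefill logprob indices and adapter
        metadata), using the fused Triton kernels when available."""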
        # Prepare values if we need to continue prefilling
        # Speculation must be ignored while we prefill even with chunking
        # it simplifies everything
        assert self.speculative_ids is None

        device = self.block_tables_tensor.device

        if isinstance(self.input_ids, list):
            if len(self) > 1:
                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
            else:
                input_ids = self.input_ids[0]
            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        self.input_lengths_tensor = torch.tensor(
            self.input_lengths, dtype=torch.int32, device=device
        )
        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
        torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
        self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
        self.cache_lengths_tensor = torch.tensor(
            self.cache_lengths, dtype=torch.int32, device=device
        )

        # If the device supports Triton, we can use a fused kernel
        if has_triton():
            self.position_ids = torch.empty(
                len(self.input_ids), dtype=torch.int32, device=device
            )
            self.slot_indices = torch.empty(
                len(self.input_ids), dtype=torch.int64, device=device
            )
            cu_slots_gpu = self.cu_slots.to(device)

            prepare_position_slot_ids(
                self.max_input_length,
                self.cache_lengths_tensor,
                self.cu_seqlen_prefill,
                cu_slots_gpu,
                self.position_ids,
                self.slot_indices,
            )

        sliding_window = get_sliding_windows()
        position_ids = []
        slot_indices = []
        prefill_cache_indices = []
        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_cu_outlens = [0]

        # Cumulative length
        cumulative_length = 0
        cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0

        adapter_indices_list = []
        adapter_set = set()

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
            blocks,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
                self.block_tables,
            )
        ):
            next_chunk_length = input_length

            if not has_triton():
                # Position ids
                request_position_ids = torch.arange(
                    cache_length, cache_length + input_length, dtype=torch.int32
                )
                position_ids.append(request_position_ids)

                if not r.slots:
                    request_slots = [
                        s
                        for b in blocks
                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                    ]
                else:
                    request_slots = r.slots

                request_slot_indices = torch.arange(
                    cache_length + cumulative_slot_tokens,
                    cache_length + cumulative_slot_tokens + input_length,
                    dtype=torch.int64,
                )

                slot_indices.append(request_slot_indices)

                # Update
                cumulative_slot_tokens += len(request_slots)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )

            # Prefill logprobs is ignored if the request is done prefilling
            prefill_logprobs = r.prefill_logprobs and request_prefilling

            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

            if prefill_logprobs:
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            if sliding_window is not None:
                prefill_cache_indices.append(request_prefill_cache_indices)

            ADAPTER_TO_INDEX = get_adapter_to_index()
            if ADAPTER_TO_INDEX:
                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
                adapter_indices_list.append(
                    torch.full((next_chunk_length,), adapter_index)
                )
                adapter_set.add(adapter_index)

            # Update
            cumulative_length += next_chunk_length

        if not all_prefill_logprobs and not no_prefill_logprobs:
            prefill_head_indices = []
            prefill_next_token_indices = []

            # Cumulative length
            cumulative_length = 0
            prefill_out_cumulative_length = 0

            for i, (
                r,
                input_length,
                request_prefilling,
            ) in enumerate(
                zip(
                    self.requests,
                    self.input_lengths,
                    self.prefilling_mask,
                )
            ):
                # Prefill logprobs is ignored if the request is done prefilling
                prefill_logprobs = r.prefill_logprobs and request_prefilling

                if prefill_logprobs:
                    prefill_head_indices.append(
                        torch.arange(
                            cumulative_length,
                            cumulative_length + input_length,
                            dtype=torch.int64,
                        )
                    )
                    prefill_next_token_indices.append(
                        prefill_out_cumulative_length + input_length - 1
                    )
                    prefill_out_cumulative_length += input_length
                else:
                    prefill_head_indices.append(
                        torch.tensor(
                            [cumulative_length + input_length - 1],
                            dtype=torch.int64,
                        )
                    )
                    prefill_next_token_indices.append(prefill_out_cumulative_length)
                    prefill_out_cumulative_length += 1

                # Update
                cumulative_length += input_length

        if len(self) > 1:
            if position_ids:
                position_ids = torch.cat(position_ids)
            if slot_indices:
                slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            if position_ids:
                position_ids = position_ids[0]
            if slot_indices:
                slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        if not has_triton():
            self.position_ids = position_ids.to(device)
            self.slot_indices = slot_indices.to(device)

        self.prefill_cu_outlens = prefill_cu_outlens
        self.prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        self.prefill_head_indices = prefill_head_indices
        self.prefill_next_token_indices = prefill_next_token_indices

        if adapter_set:
            adapter_indices = torch.cat(adapter_indices_list).to(
                dtype=torch.int64, device=device
            )
            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        else:
            adapter_indices = torch.zeros_like(self.input_ids)
            adapter_segments = [0, len(adapter_indices)]
            adapter_segment_indices = [len(adapter_indices) - 1]

        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        self.adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

    def __len__(self):
        return len(self.requests)


ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
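# Note (hedged): these projection names are the modules LoRA adapters are expected to
# target; ROW_PARALLEL presumably marks the layers whose weights are sharded row-wise
# under tensor parallelism, so their adapter weights have to be split differently.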


class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = [],
        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
        config_class: PreTrainedTokenizerBase = AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
        kv_cache_dtype: Optional[torch.dtype] = None,
        support_chunking: bool = True,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
                init_cpu_threads_env(rank_id=rank, world_size=world_size)
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            pass

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads // self.process_group.size()
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0
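        # Sharding note: with grouped/multi-query attention the KV heads are divided
        # across tensor-parallel ranks when there is more than one of them.
        # Illustrative example (hypothetical values): num_key_value_heads=8 with a
        # tensor-parallel world size of 4 gives self.num_kv_heads = 2 per rank, while
        # multi-query attention (num_kv_heads=1) keeps the single head on every rank.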

        if head_size is None:
            # Some models use GQA and different head sizes for o_proj
            # and q_proj; the config's `head_dim` accounts for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
            support_chunking=support_chunking,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def max_past(self) -> Optional[int]:
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.kv_cache = []
        empty_cache()
        self.kv_cache = [
            KVCache(
                num_blocks=num_blocks,
                num_heads=num_heads,
                head_size=head_size,
                dtype=dtype,
                device=device,
            )
            for _ in range(num_layers)
        ]
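        # Layout sketch (descriptive, assuming the usual paged-attention layout of the
        # KVCache helper): each entry holds the key/value cache of one transformer
        # layer, with room for num_blocks * BLOCK_SIZE token slots of num_heads KV
        # heads of head_size each; attention kernels address it through the per-request
        # block tables and slot indices.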
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
        input_lengths = [max_s] * bs
        cache_lengths = [0] * bs
        if max_bs is None:
            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
            input_lengths_tensor = (
                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
            )
            cache_lengths_tensor = torch.zeros(
                bs, dtype=torch.int32, device=self.device
            )
            block_tables = torch.arange(
                max_bt, dtype=torch.int32, device=self.device
            ).repeat(bs)
            block_tables = block_tables.reshape((bs, max_bt))
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths,
                    input_lengths_tensor=input_lengths_tensor,
                    cache_lengths_tensor=cache_lengths_tensor,
                    max_current_length=max_s,
                )
        else:
            if bs > max_bs:
                raise RuntimeError(
                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
                )
            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
            if ATTENTION == "flashinfer":
                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
            else:
                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
            slots = self.cuda_graphs[max_bs]["slots"][:bs]
            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )

            block_tables_ptr = torch.zeros(
                bs + 1, dtype=torch.int32, device=self.device
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
        else:
            state = None

        if (
            hasattr(self.model, "config")
            and hasattr(self.model.config, "model_type")
            and self.model.config.model_type == "qwen2_vl"
        ):
            if position_ids.dim() == 1:
                position_ids = self.model.get_position_ids(input_ids)

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
            "cache_lengths": cache_lengths_tensor,
            "state": state,
            "graph": graph,
        }
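        # Capture flow: run the decode forward once eagerly to warm up kernels and
        # allocators, then record the identical call into `graph` under the shared
        # memory pool, so `forward` can later replay it with `graph.replay()` instead
        # of re-launching every kernel for this (padded) batch size.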

        torch.cuda.synchronize()
        # Run once outside to warmup
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths_tensor=input_lengths_tensor,
            state=state,
            cache_lengths_tensor=cache_lengths_tensor,
        ):
            seqlen = Seqlen(
                input_lengths=input_lengths_tensor,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            del seqlen

            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                seqlen = Seqlen(
                    input_lengths=input_lengths_tensor,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=None,
                    max_q=1,
                    max_k=max_s,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=None,
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
                )
                self.cuda_graphs[bs]["logits"] = logits
                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

    def warmup(
        self,
        batch: FlashCausalLMBatch,
        max_input_tokens: Optional[int],
        max_total_tokens: Optional[int],
    ):
        # The warmup batch is the biggest batch we could ever receive
        self.kv_cache = []
        empty_cache()

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
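        # Illustrative sizing (hypothetical Llama-7B-like values, not read from any config):
        #   BLOCK_SIZE=16, num_kv_heads=32, head_size=128, num_layers=32, fp16 (2 bytes)
        #   cache_block_size = 16 * 32 * 128         = 65,536 elements per layer
        #   total_cache_size = 32 * 65,536 * 2 * 2   = 8,388,608 bytes (~8 MiB per block)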

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.kv_cache_dtype,
                self.device,
            )

            batch_num_blocks = batch.num_blocks

            num_tokens = batch.to_pb().current_tokens
            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                #torch.cuda.tunable.tuning_enable(False)
                pass
            synchronize(self.device)
            free_memory = get_free_memory(
                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
            )
            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
            log_master(
                logger.debug,
                f"Free memory {free_memory/1e9:.2f}GB , (real: {real_free_memory/1e9:.2f}GB",
            )

            _, _batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {num_tokens} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)
        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
        kv_memory = free_memory
        num_blocks = (
            # Leave 5% for some wiggle room
            int(kv_memory // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )

        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
        if max_total_tokens is None:
            if get_support_chunking():
                model_max_length = self.tokenizer.model_max_length
                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
            else:
                max_total_tokens = sum(batch.cache_lengths)

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1

        del _batch, batch
        self.kv_cache = []
        empty_cache()

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.kv_cache_dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            if (
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                #torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    #torch.cuda.tunable.tuning_enable(True)
                    pass

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                log_master(
                    logger.info,
                    f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                )

                # torch.cuda.tunable.set_filename(
                #     tunableop_filepath, insert_device_ordinal=False
                # )

                # if os.path.isfile(tunableop_filepath):
                #     log_master(
                #         logger.info,
                #         f"The file {tunableop_filepath} already exists and will be reused.",
                #     )
                #     torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                # for seqlen in tuning_sequences:
                #     log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                #     self.tunableop_warmup(seqlen)
                #     torch.cuda.tunable.write_file(tunableop_filepath)
                # if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                #     torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                )

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    synchronize(self.device)
                    free_memory = get_free_memory(
                        self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
                    )
                    log_master(
                        logger.debug,
                        f"Free RAM before cuda graph {bs} {free_memory / 1e9:.2f}GB",
                    )
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
                empty_cache()
                synchronize(self.device)
                free_memory = get_free_memory(
                    self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
                )
                log_master(
                    logger.debug,
                    f"Free RAM after cuda graphs {free_memory / 1e9:.2f}GB",
                )
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )

        assert max_input_tokens is not None
        assert max_total_tokens is not None
        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

    def tunableop_warmup(self, seqlen: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        cache_lengths_tensor = torch.zeros(
            seqlen, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.tensor(
            [0, seqlen], device=self.device, dtype=torch.int32
        )
        max_s = seqlen
        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
            max_q=1,
            max_k=seqlen,
        )
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=None,
            seqlen=seqlen,
            slots=slots,
            max_s=max_s,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
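            # Illustrative shapes (hypothetical): with B=2 requests and speculative_length=2,
            # each request contributes its next token plus 2 speculative candidates, so
            # new_input_ids, new_position_ids, slots and input_lengths are all flattened
            # to B * new_length = 6 rows that are verified in a single forward pass.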
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)

            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
            # allocated
            slot_indices = (
                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            slots = batch.slots[slot_indices]

            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
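        # Selection sketch: graphs were captured for fixed batch sizes during warmup; a
        # decode batch of size bs replays the smallest captured graph whose capacity is
        # >= bs and pads the unused rows. For example (hypothetical buckets), with graphs
        # for sizes {1, 2, 4, 8}, bs=3 would replay the size-4 graph.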

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
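                # Ragged layout sketch: instead of one padded row of block ids per
                # request, flashinfer receives every request's blocks concatenated into
                # a single 1-D tensor, indexed via the per-request lengths above.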
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    adapter_data=adapter_data,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.prefilling
        if prefill:
            batch.prepare_for_prefill()

        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
        else:
            prefill_logprobs = None
            next_token_logits = out

        finished_prefilling = True
        next_chunk_lengths = []
        current_prefilling_mask = batch.prefilling_mask
        if prefill:
            if get_support_chunking():
                next_prefilling_mask = []
                # Budget in tokens for the next batch
                # We subtract (len(batch) - 1) to always leave room for at least a single decode token
                # for each of the other requests; the first request does not need to be removed from the budget
                # (e.g. with a single request in the batch, it should take the full budget, not budget - 1)
                batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
                # We reverse to prioritize older requests
                # zip() is not reversible so reverse the underlying lists instead
                for cache_length, input_length, prompt_length in zip(
                    reversed(batch.cache_lengths),
                    reversed(batch.input_lengths),
                    reversed(batch.prompt_lengths),
                ):
                    remaining_prefill_tokens = max(
                        prompt_length - cache_length - input_length, 0
                    )
                    if remaining_prefill_tokens > 0:
                        next_chunk_length = max(
                            min(remaining_prefill_tokens, batch_budget), 1
                        )
                        batch_budget -= next_chunk_length
                        finished_prefilling = False
                        next_prefilling_mask.append(True)
                    else:
                        # FIXME: use true number of accepted tokens instead of 1
                        # Since speculation will be turned off, this is always true
                        next_chunk_length = 1
                        next_prefilling_mask.append(False)
                    next_chunk_lengths.append(next_chunk_length)

                # Reverse back the obtained values
                next_chunk_lengths.reverse()
                next_prefilling_mask.reverse()
            else:
                # The model does not support chunking
                # We know we only do a single prefill
                finished_prefilling = True
                next_prefilling_mask = [False] * len(batch)

            batch.prefilling = not finished_prefilling
            batch.prefilling_mask = next_prefilling_mask
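            # Illustrative chunking walk-through (hypothetical numbers): with
            # get_max_prefill_tokens() = 512 and three requests whose remaining prompt
            # tokens are [600, 100, 0], the budget starts at 512 - 2 = 510; the first
            # request gets a 510-token chunk, the second is clamped to a 1-token chunk,
            # and the already-prefilled request advances by a single decode token.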

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_current_length],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        # Since we are done prefilling, all the tensors that were concatenating values for all the requests
        # instantly become of shape [BATCH_SIZE]
        if prefill and finished_prefilling:
            indices = batch.cu_seqlen_prefill[1:] - 1
            batch.position_ids = batch.position_ids[(..., indices)]
            batch.slot_indices = batch.slot_indices[indices]
            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                indices
            ]

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids,
            current_prefilling_mask,
            batch.prefilling_mask,
        )

        # We do two for loops: the first one can run completely asynchronously from the GPU,
        # while the second one requires a GPU <-> CPU sync first.
        # It is faster if we delay this sync for as long as possible.

        # For each member of the batch
        # Cumulative length
        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
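        # Illustrative example: accepted_ids = [1, 3, 2] yields cu_accepted_ids = [0, 1, 4, 6],
        # so request i's newly accepted tokens live in
        # next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].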
        cumulative_length = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            all_input_ids,
            n_accepted_ids,
            request_was_prefilling,
            request_is_prefilling,
        ) in enumerate(iterator):
            # Used to gather prefill logprobs
            # Copy batch.all_input_ids_tensor to prefill_token_indices
            if request.prefill_logprobs and request_was_prefilling:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                # Logprobs generated by the model are for the next token
                # So we need to translate the id tensor by 1
                ids = batch.all_input_ids_tensor[
                    i, cache_length + 1 : cache_length + input_length + 1
                ]
                if len(batch) > 1:
                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = ids

            # If the device does not support triton, we copy one by one
            if not request_is_prefilling and not has_triton():
                # Only save tokens if we are done prefilling for this request
                batch.all_input_ids_tensor[
                    i,
                    batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
            cumulative_length += input_length

        # If the device supports triton, we can use a fused kernel
        if has_triton():
            copy_next_input_ids_inplace(
                speculate + 1,
                batch.all_input_ids_tensor,
                batch.cache_lengths_tensor,
                batch.input_lengths_tensor,
                batch.prompt_lengths_tensor,
                next_input_ids,
                cu_accepted_ids,
            )

        # Update values
        # These values can be updated without a GPU -> CPU sync
        if not prefill or (prefill and finished_prefilling):
            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
            batch.speculative_ids = speculative_ids
            batch.position_ids += accepted_ids
            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
            batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
            torch.log_softmax(out, -1, out=out)
            prefill_logprobs_tensor = out
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # Does a GPU <-> CPU sync internally
        if prefill and finished_prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        # Update values if we need to continue prefilling
        # This represents the `else` case of the `Update values` if above
        # but since this require the `next_token_ids` to be on CPU, it is better to do it here
        if prefill and not finished_prefilling:
            # Speculation must be ignored while we prefill even with chunking
            # it simplifies everything
            assert batch.speculative_ids is None

            all_postfix_ids = []
            for i, (
                request_prefilling,
                next_token_id,
                all_input_ids,
                cache_length,
                input_length,
                next_chunk_length,
            ) in enumerate(
                zip(
                    batch.prefilling_mask,
                    next_token_ids,
                    batch.all_input_ids,
                    batch.cache_lengths,
                    batch.input_lengths,
                    next_chunk_lengths,
                )
            ):
                if request_prefilling:
                    next_cache_length = cache_length + input_length
                    # Get new prompt IDs to prefill
                    postfix_ids = all_input_ids[
                        next_cache_length : next_cache_length + next_chunk_length
                    ]
                else:
                    # This request is done prefilling, the new id is the one selected by the sampling method
                    postfix_ids = [next_token_id]

                all_postfix_ids.append(postfix_ids)

            batch.input_ids = all_postfix_ids

        start_decode = time.time_ns()

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            current_prefilling_mask,
            batch.prefilling_mask,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # Reset max_input_length
        batch.max_input_length = 0
        # For each member of the batch
        index = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            request_was_prefilling,
            request_is_prefilling,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Compute logprobs first as, even though we might skip the token,
            # it can still be required to compute the logprobs
            # We take request.id modulo world_size because it is robust to batch.filter,
            # whereas the index in the batch is not, and we need this state to be stable
            if request.id % self.world_size == self.rank:
                # Prefill
                if request_was_prefilling and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]
                    if not request_is_prefilling:
                        # The request is done prefilling, meaning that we started generating new tokens
                        # The last logprob is a logprob for a generated token that was not part of the prompt
                        # We need to remove it
                        out_end_index -= 1

                    request_prefill_logprobs = prefill_logprobs[
                        out_start_index:out_end_index
                    ]
                    # Logprobs generated by the model are for the next token
                    # So we need to translate the id tensor by 1
                    prefill_token_ids = all_input_ids[
                        cache_length + 1 : cache_length + input_length + 1
                    ]

                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

                    if past_prefill_logprob_tokens is None:
                        # add nan for cached prompt tokens/first token
                        request_prefill_logprobs = [float("nan")] * (
                            cache_length + 1
                        ) + request_prefill_logprobs
                        prefill_token_ids = (
                            all_input_ids[: cache_length + 1] + prefill_token_ids
                        )

                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_logprob_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                    if past_prefill_logprob_tokens is not None:
                        prefill_logprob_tokens = (
                            past_prefill_logprob_tokens + prefill_logprob_tokens
                        )

                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                else:
                    batch.prefill_logprob_tokens[i] = None

            # If the request is still prefilling, the tokens we decoded should be ignored
            if request_is_prefilling:
                # Make sure that we do not stop as even though this request did not create a token, it is still
                # processing
                stopped = False
                new_input_length = next_chunk_lengths[i]
                new_cache_length = cache_length + input_length
            else:
                new_input_length = 1
                new_cache_length = cache_length + input_length + n_accepted_ids - 1
                # Append next token to all tokens
                next_token_texts = []
                left = 0

                if n_accepted_ids > 1:
                    log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

                current_stopped = False
                for j in range(index, index + n_accepted_ids):
                    # Generated token
                    next_token_id = next_token_ids[j]
                    all_input_ids.append(next_token_id)
                    next_token_text, prefix_offset, read_offset = self.decode_token(
                        all_input_ids,
                        prefix_offset,
                        read_offset,
                    )
                    next_token_texts.append(next_token_text)

                    stop, reason = stopping_criteria(
                        next_token_id,
                        next_token_text,
                    )

                    if stop:
                        left = index + n_accepted_ids - j - 1
                        current_stopped = True
                        break
                    else:
                        current_stopped = False
                stopped = stopped and current_stopped

                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
                _next_token_logprobs = next_token_logprobs[
                    index : index + n_accepted_ids - left
                ]
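                # Illustrative example: with n_accepted_ids = 3 and a stop criterion
                # triggering on the second accepted token, left = 1, so only the first
                # two token ids/logprobs are kept for this request's Generation.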

                # Shard generations
                # All generations will be appended in the rust sharded client
                if request.id % self.world_size == self.rank:
                    if stop:
                        # Decode generated tokens
                        output_text, _, _ = self.decode_token(
                            all_input_ids,
                            prefix_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens
                            - 1,
                            read_offset=len(all_input_ids)
                            - stopping_criteria.current_tokens,
                            skip_special_tokens=True,
                        )
                        generated_text = GeneratedText(
                            output_text,
                            stopping_criteria.current_tokens,
                            reason,
                            seed if do_sample else None,
                        )
                    else:
                        generated_text = None

                    if top_n_tokens > 0:
                        all_top_tokens = []
                        for top_token_ids, top_token_logprobs in zip(
                            top_token_ids, top_token_logprobs
                        ):
                            toptoken_texts = self.tokenizer.batch_decode(
                                top_token_ids,
                                clean_up_tokenization_spaces=False,
                                skip_special_tokens=False,
                            )
                            special_toptokens = [
                                token_id in self.all_special_ids
                                for token_id in top_token_ids
                            ]
                            top_tokens = Tokens(
                                top_token_ids,
                                top_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                            all_top_tokens.append(top_tokens)
                        top_tokens = all_top_tokens
                    else:
                        top_tokens = None

                    generation = Generation(
                        request.id,
                        batch.prefill_logprob_tokens[i],
                        Tokens(
                            _next_token_ids,
                            _next_token_logprobs,
                            next_token_texts,
                            [nid in self.all_special_ids for nid in _next_token_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )

                    generations.append(generation)

                # accept each new token for this specific request since we may
                # have more than one new token per request with speculative decoding
                for next_token_id in _next_token_ids:
                    batch.next_token_chooser = (
                        batch.next_token_chooser.advance_grammar_single(
                            i, next_token_id
                        )
                    )

            # Update values
            index += n_accepted_ids
            batch.cache_lengths[i] = new_cache_length
            batch.max_input_length = max(batch.max_input_length, new_input_length)
            batch.input_lengths[i] = new_input_length
            current_length = new_cache_length + new_input_length
            batch.max_current_length = max(batch.max_current_length, current_length)

            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        if prefill and finished_prefilling:
            # We do not need prefill tensors anymore
            batch.cu_seqlen_prefill = None
            batch.prefill_cache_indices = None
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

    def _forward_context(
        self,
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        input_lengths_tensor: torch.Tensor,
        cache_lengths_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
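        """Return the attention context manager for this forward pass.

        Only flashinfer keeps per-forward prefill/decode state; every other
        attention backend runs without a dedicated context.
        """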
        if ATTENTION != "flashinfer":
            return nullcontext()

        from text_generation_server.layers.attention.flashinfer import (
            use_decode_state,
            use_prefill_with_paged_kv_state,
        )

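        # Prefill steps carry cu_seqlen_prefill; otherwise this is a decode step.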
        if cu_seqlen_prefill is not None:
            return use_prefill_with_paged_kv_state(
                state=(
                    state if state is not None else self.prefill_with_paged_kv_state
                ),
                block_tables=block_tables,
                cu_seqlens=cu_seqlen_prefill,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )
        else:
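            # Decode step: reuse the shared decode state unless one is passed in.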
            assert input_lengths_tensor is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                block_tables=block_tables,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
                kv_cache_dtype=self.kv_cache_dtype,
                dtype=self.dtype,
                window_left=self.sliding_window,
            )