flash_causal_lm.py 68 KB
Newer Older
1
from contextlib import nullcontext
2
import math
3
import os
4
import time
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
14
15
16
17
18
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
19
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
20

drbh's avatar
drbh committed
21
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
fxmarty's avatar
fxmarty committed
22
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Daniël de Kok's avatar
Daniël de Kok committed
23
from text_generation_server.utils.chunks import concat_text_chunks
Nicolas Patry's avatar
Nicolas Patry committed
24
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
25
from text_generation_server.models import Model
26
from text_generation_server.utils.log import log_master
27
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
28
from text_generation_server.utils.speculate import get_speculate
29
30
31
32
33
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
34
35
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
36
    Tokens,
37
38
39
40
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
Nicolas Patry's avatar
Nicolas Patry committed
41
42
from text_generation_server.models.globals import (
    MEM_POOL,
43
    ATTENTION,
44
    BLOCK_SIZE,
Nicolas Patry's avatar
Nicolas Patry committed
45
46
47
    CUDA_GRAPHS,
    get_adapter_to_index,
)
48
from text_generation_server.layers.attention import Seqlen
49
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
50
from text_generation_server.utils.dist import MEMORY_FRACTION
51
from text_generation_server.utils.quantization import get_loader
drbh's avatar
drbh committed
52
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
53

Nicolas Patry's avatar
Nicolas Patry committed
54
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
55
56
57
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
58
59
)

Nicolas Patry's avatar
Nicolas Patry committed
60
61
# OpenTelemetry tracer for this module; batch methods below are wrapped in
# spans via ``@tracer.start_as_current_span(...)``.
tracer = trace.get_tracer(__name__)

62
63
64
65
66
67
68
69
70
71
72
73
74
75

# Will be set in init via set_sliding_window(); None means "no sliding window".
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: Optional[int]):
    """Store the sliding-window size later read back by ``get_sliding_windows``."""
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    """Return the stored sliding-window size, or ``None`` if it was never set."""
    # Read-only access to a module global needs no ``global`` declaration.
    return SLIDING_WINDOW

76

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
def init_cpu_threads_env(rank_id: int, world_size: int):
    import importlib.util

    if importlib.util.find_spec("numa") is not None:
        import numa
        import psutil

        nodes = numa.get_max_node() + 1
        rank_per_node = math.ceil(world_size / nodes)
        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
        node_id = int(rank_id / rank_per_node)
        rank_offset_per_node = rank_id % rank_per_node
        if os.getenv("OMP_NUM_THREADS") is None:
            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
        else:
            num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
        if len(numa.get_membind()) == nodes:
            numa.set_membind([node_id])
        torch.set_num_threads(num_cpus_per_rank)
        if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True):
            cpu_start = num_cpus_per_rank * rank_offset_per_node
            numa.set_affinity(
                0,
                list(numa.node_to_cpus(node_id))[
                    cpu_start : cpu_start + num_cpus_per_rank
                ],
            )
        logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}")


107
108
109
110
@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
111
112
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]
113
114

    # Decoder values
115
116
    input_ids: torch.Tensor
    position_ids: torch.Tensor
117
    speculative_ids: Optional[torch.Tensor]
118

119
120
121
122
    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
123
124
125
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]
126
127
128
129
130
131
132
133
134
135

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

    # list of length b of list of length s_i // block_size
136
    block_tables: List[List[int]]
137
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
138
    block_tables_tensor: torch.Tensor
139
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
140
    slots: torch.Tensor
141

142
143
    max_seqlen: int

144
145
146
147
148
    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.tensor]
    prefill_cu_outlens: Optional[List[int]]

149
150
    # All tokens
    all_input_ids: List[List[int]]
151
    all_input_ids_tensor: torch.Tensor
152
153
154

    # Lengths of all generations present in the batch
    input_lengths: List[int]
155
    input_lengths_tensor: torch.Tensor
156
157
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]
158
159

    # Generation helpers
160
    next_token_chooser: HeterogeneousNextTokenChooser
161
    stopping_criterias: List[StoppingCriteria]
Nicolas Patry's avatar
Nicolas Patry committed
162
163
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor
164

drbh's avatar
drbh committed
165
166
167
    # Adapter metadata for each request
    adapter_meta: AdapterBatchMetadata

168
    # Number of blocks in this batch
169
    num_blocks: int
170
171
    # Maximum number of blocks
    max_blocks: int
172

173
174
    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize this batch as a ``generate_pb2.CachedBatch``.

        ``max_tokens`` is reported as ``num_blocks * BLOCK_SIZE``.
        """
        request_ids = [request.id for request in self.requests]
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=request_ids,
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        """Tokenize the text of every request with a single tokenizer call.

        Each request's input chunks are joined with ``concat_text_chunks``.
        All texts are truncated to the largest ``truncate`` among the requests;
        the tighter per-request truncation is applied in ``from_tokenized``.

        Returns:
            The tokenizer output's ``"input_ids"`` entry (presumably one list
            of token ids per request, in request order).
        """
        texts = []
        longest_truncate = 0
        for request in requests:
            texts.append(concat_text_chunks(request.input_chunks.chunks))
            longest_truncate = max(longest_truncate, request.truncate)

        encoded = tokenizer(texts, truncation=True, max_length=longest_truncate)
        return encoded["input_ids"]
195

drbh's avatar
drbh committed
196
197
198
199
200
201
202
203
204
    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from already-tokenized requests.

        Args:
            pb: protobuf batch; ``pb.requests`` is zipped with
                ``batch_tokenized_inputs``.
            tokenizer: used for the duplicate-BOS check, the stopping criteria
                and the next-token chooser.
            batch_tokenized_inputs: one sequence of token ids per request.
            dtype: dtype forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: device on which the batch tensors are created.

        Returns:
            A new batch whose prefill fields (``cu_seqlen_prefill``,
            ``prefill_*``) are populated.
        """
        sliding_window = get_sliding_windows()
        # Loop-invariant lookups: fetch them once instead of once per request.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length
        adapter_to_index = get_adapter_to_index()

        position_ids = []
        cu_seqlen_prefill = [0]
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        adapter_indices_list = []
        adapter_set = set()

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        block_tables = []
        slots = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop one of two consecutive leading BOS tokens. The length check
            # avoids an IndexError on inputs shorter than two tokens.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            adapter_index = adapter_to_index.get(r.adapter_id, 0)
            adapter_indices_list.append(torch.full((input_length,), adapter_index))
            adapter_set.add(adapter_index)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            slots.extend(request_slots[:total_tokens])
            num_blocks += len(request_blocks)
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Keep logits for every prompt position of this request.
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Only the last prompt position is needed to pick the next token.
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        adapter_indices = torch.cat(adapter_indices_list).to(
            dtype=torch.int64, device=device
        )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # torch.cat already returns a tensor: convert it with .to() instead
            # of copy-constructing it via torch.tensor(), which warns.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        # Fill the padded block table on CPU, then move it to the device once.
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
            speculative_ids=None,
        )

473
474
475
476
477
478
479
480
481
482
483
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize the requests of ``pb`` and build a prefill batch from them.

        Chains ``batch_tokenized_inputs`` and ``from_tokenized``; both are
        looked up on ``cls`` so subclass overrides are used.
        """
        tokenized = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        batch = cls.from_tokenized(pb, tokenizer, tokenized, dtype, device)
        return batch

484
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch that keeps only the requests listed in ``request_ids``.

        Kept requests are re-packed in the order of ``request_ids``: per-request
        lists are rebuilt, device tensors are gathered with the collected
        indices, and ``slots`` keeps only the slot ranges of kept requests. All
        prefill-only fields of the result are set to ``None``.

        Args:
            request_ids: ids of the requests to keep; each must be a key of
                ``self.requests_idx_mapping``.

        Returns:
            ``self`` when ``request_ids`` has as many entries as the batch,
            otherwise a new batch of the same type.

        Raises:
            ValueError: if ``request_ids`` is empty.
            KeyError: if an id does not belong to this batch.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # Loop-invariant lookup: fetch the adapter mapping once, not per request.
        adapter_to_index = get_adapter_to_index()

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()

        num_blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            adapter_index = adapter_to_index.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU): index, within the re-packed slots tensor, of
            # this request's position ``input_length - 1``.
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Keep this request's slot range: current tokens plus the slots for
            # tokens it may still generate.
            # NOTE(review): from_tokenized also reserves ``speculative_length``
            # slots per request; this range does not — confirm speculative
            # decoding never indexes past it after a filter.
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        adapter_indices = self.adapter_meta.adapter_indices[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

646
        num_blocks = 0
647
648
649
650
651
652
653
654
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
655
            num_blocks += b.num_blocks
OlivierDehaene's avatar
OlivierDehaene committed
656
657
658
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
659
660
661
662
663
664
665
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
Nicolas Patry's avatar
Nicolas Patry committed
666
                    + speculative_length
667
668
669
670
671
672
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )
673
674
675

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
676
677
678
679
680
681
682
683
684
685
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
686
        )
Nicolas Patry's avatar
Nicolas Patry committed
687
688
689
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )
drbh's avatar
drbh committed
690
691
692
693
694
695
696
697
        total_indices_size = sum(
            b.adapter_meta.adapter_indices.shape[0] for b in batches
        )
        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
            total_indices_size
        )
        adapter_set = set()
        adapter_segment_builder = SegmentConcatBuilder()
698

699
700
        start_slots = []
        block_tables = []
701
702
703
        all_input_ids = []

        input_lengths = []
704
705
        prefix_offsets = []
        read_offsets = []
706

707
        next_token_chooser_parameters = []
708
        fsm_grammar_states = []
709
        stopping_criterias = []
Nicolas Patry's avatar
Nicolas Patry committed
710
        top_n_tokens = []
711

712
        # Cumulative length
713
        cumulative_batch_size = 0
714
        cumulative_slots = 0
drbh's avatar
drbh committed
715
        cumulative_adapter_indices_size = 0
716
717
718

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
719
720
721
722
723
724
725
726

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

727
728
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
729
730
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)
731
732
733
734

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
735
736
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
Nicolas Patry's avatar
Nicolas Patry committed
737
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
738
            slots[slots_start_index:slots_end_index] = batch.slots
739

drbh's avatar
drbh committed
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
            # Copy over adapter indices
            adapter_start_index = cumulative_adapter_indices_size
            adapter_end_index = (
                cumulative_adapter_indices_size
                + batch.adapter_meta.adapter_indices.shape[0]
            )
            adapter_indices[adapter_start_index:adapter_end_index] = (
                batch.adapter_meta.adapter_indices
            )
            cumulative_adapter_indices_size = adapter_end_index
            adapter_set.update(batch.adapter_meta.adapter_set)
            adapter_segment_builder.concat(
                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
            )

755
756
757
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]
758

759
760
761
            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]
762

763
764
765
            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
766
767
            all_input_ids.extend(batch.all_input_ids)

768
            input_lengths.extend(batch.input_lengths)
769
770
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
771

772
            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
773
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
774
775
            stopping_criterias.extend(batch.stopping_criterias)

Nicolas Patry's avatar
Nicolas Patry committed
776
777
            top_n_tokens.extend(batch.top_n_tokens)

778
            # Update
779
            cumulative_batch_size += len(batch)
780
            cumulative_slots += len(batch.slots)
781

782
        start_slots = torch.concat(start_slots)
783

784
        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
785
786
787
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
drbh's avatar
drbh committed
788
            tokenizer=batches[0].next_token_chooser.tokenizer,
789
            fsm_grammar_states=fsm_grammar_states,
790
791
        )

OlivierDehaene's avatar
OlivierDehaene committed
792
793
794
795
796
        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )
Nicolas Patry's avatar
Nicolas Patry committed
797

drbh's avatar
drbh committed
798
799
        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

800
        return cls(
801
802
            batch_id=batches[0].batch_id,
            requests=requests,
803
            requests_idx_mapping=requests_idx_mapping,
804
805
            input_ids=input_ids,
            position_ids=position_ids,
806
            cu_seqlen_prefill=None,
807
            prefill_cache_indices=None,
808
809
810
811
812
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
813
            max_seqlen=max_seqlen,
814
815
816
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
817
            input_lengths=input_lengths,
818
            input_lengths_tensor=input_lengths_tensor,
819
820
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
821
822
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
823
            next_token_chooser=next_token_chooser,
824
            stopping_criterias=stopping_criterias,
Nicolas Patry's avatar
Nicolas Patry committed
825
826
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
827
            num_blocks=num_blocks,
828
            max_blocks=max_blocks,
OlivierDehaene's avatar
OlivierDehaene committed
829
            speculative_ids=speculative_ids,
drbh's avatar
drbh committed
830
831
832
833
834
835
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
836
837
838
839
840
841
        )

    def __len__(self):
        """Return how many requests this batch holds."""
        request_list = self.requests
        return len(request_list)


842
843
844
845
846
847
848
849
850
851
852
853
# Projection submodule names listed as adapter targets, in this order:
# the q/k/v/o projections, then the gate/up/down projections.
ADAPTER_LAYERS = [
    *("q_proj", "k_proj", "v_proj", "o_proj"),
    *("gate_proj", "up_proj", "down_proj"),
]
# Module names treated as row-parallel. NOTE(review): the code that reads
# this set is outside this chunk — confirm how it is consumed.
ROW_PARALLEL = {"lm_head", "o_proj", "down_proj"}


854
855
856
class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = None,
        tokenizer_class: Any = AutoTokenizer,
        config_class: Any = AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        """Load tokenizer, config and sharded weights, then build `model_class`.

        Args:
            model_id: Model id/path passed to the tokenizer, generation-config,
                config and weight-file loaders.
            model_class: Constructed as ``model_class(prefix, config, weights)``.
            revision: Revision forwarded to every ``from_pretrained`` call and
                to ``weight_files``.
            quantize: Stored on ``config.quantize`` and passed to ``get_loader``.
            speculator: Stored on ``config.speculator``.
            dtype: Weight dtype. When ``None``, ``default_dtype`` is used on
                CUDA/XPU and ``torch.bfloat16`` on IPEX CPU.
            trust_remote_code: Forwarded to the ``from_pretrained`` calls.
            lora_adapter_ids: Not used by this constructor.
            tokenizer_class: Class whose ``from_pretrained`` loads the tokenizer.
            config_class: Class whose ``from_pretrained`` loads the model config.
            default_dtype: Fallback dtype on CUDA and XPU.
            aliases: Weight-name aliases forwarded to ``Weights``.
            num_kv_heads: Overrides the key/value head count read from config.
            head_size: Overrides the per-head size read from config.
            skip_special_tokens: NOTE(review): accepted but not used by this
                constructor — confirm whether it should be forwarded.

        Raises:
            NotImplementedError: If neither CUDA nor IPEX is available.
            ValueError: If the number of key/value heads cannot be determined.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
                init_cpu_threads_env(rank_id=rank, world_size=world_size)
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        # Best-effort: a missing or unreadable generation config is not fatal,
        # but record why it was skipped instead of silently swallowing it.
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
        except Exception as e:
            logger.debug(f"Skipping generation config for {model_id}: {e}")
            generation_config = None
        if generation_config is not None and isinstance(
            generation_config.eos_token_id, (list, set)
        ):
            # TODO Huge hack: expose every EOS id from the generation config
            # on the tokenizer.
            tokenizer._eos_token_ids = set(generation_config.eos_token_id)

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        # Shard KV heads across ranks; a single KV head is kept as-is.
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0, (
            f"Got {self.num_kv_heads} key/value heads per shard "
            f"({num_kv_heads} heads, world size {self.process_group.size()})"
        )

        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj, that allows for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flash_infer import (
                create_prefill_state,
                create_decode_state,
            )

            self.prefill_state = create_prefill_state(device=device)

            # With CUDA graphs, decode states are created per captured graph
            # in `cuda_graph_warmup` instead.
            if not CUDA_GRAPHS:
                self.decode_state = create_decode_state(
                    device=device,
                    num_heads=self.num_heads,
                    num_kv_heads=self.num_kv_heads,
                )

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class consumed by this model."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
    def max_past(self) -> Optional[int]:
        """Return the wrapped model's ``max_past`` attribute, or ``None``.

        ``None`` is returned when the model defines no ``max_past``; callers
        such as ``forward`` only cap the decode ``max_s`` when it is not None.
        """
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """(Re)allocate the KV cache as a list of ``(key, value)`` tensor pairs.

        The previous cache is dropped and ``empty_cache()`` is called before
        allocating ``num_layers`` new pairs. Tensor shapes depend on the
        attention backend and device:

        * ``flashdecoding`` / ``flashinfer``: key and value are
          ``(num_blocks, BLOCK_SIZE, num_heads, head_size)``.
        * IPEX on CPU: key and value are
          ``(num_blocks, num_heads, BLOCK_SIZE, head_size)``.
        * otherwise: key is ``(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)``
          and value is ``(num_blocks, num_heads, head_size, BLOCK_SIZE)``, with
          ``x = 1`` on IPEX XPU and ``x = BLOCK_SIZE // element_size`` elsewhere.

        Tensors come from ``torch.empty`` and are therefore uninitialized.
        """
        # Release the previous cache before allocating the new one.
        self.kv_cache = []
        empty_cache()

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            key_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
            value_shape = key_shape
        elif SYSTEM == "ipex" and device.type == "cpu":
            # Compare on `device.type` (like the XPU check below) so an indexed
            # CPU device such as `cpu:0` is routed here too.
            key_shape = (num_blocks, num_heads, BLOCK_SIZE, head_size)
            value_shape = key_shape
        else:
            if SYSTEM == "ipex" and device.type == "xpu":
                x = 1
            else:
                element_size = torch.tensor([], dtype=dtype).element_size()
                x = BLOCK_SIZE // element_size
            key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
            value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)

        self.kv_cache = [
            (
                torch.empty(key_shape, dtype=dtype, device=device),
                torch.empty(value_shape, dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]
1071

1072
1073
1074
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for batch size `bs`.

        Allocates static dummy inputs for `bs` rows and `max_bt` block-table
        columns, and stores them in ``self.cuda_graphs[bs]``. The same entry
        also holds the graph, the flashinfer decode state (when used) and the
        captured output logits. The method runs the model once eagerly as a
        warmup, then records the same forward into a ``torch.cuda.CUDAGraph``
        using the ``MEM_POOL`` pool. ``forward`` later copies real inputs into
        these buffers and replays the graph.

        Args:
            bs: Number of rows the graph is captured for.
            max_s: Maximum sequence length; also every dummy input length.
            max_bt: Number of block-table columns in the static buffer.
        """
        # Dummy static inputs: zero ids/positions, slot i for row i.
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        # Every dummy sequence is at length `max_s`.
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        # Every row gets block ids 0..max_bt-1.
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )

        # Static buffers; `forward` overwrites them in place before each replay.
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        # Seqlen wrapper used for the eager warmup pass below.
        input_lengths_ = Seqlen(input_lengths=input_lengths)
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flash_infer import (
                create_decode_state_cuda_graphs,
            )

            # The decode state is built on top of the static block tables
            # (flattened view) and its own fixed-size helper buffers.
            block_tables_ptr = torch.zeros(
                bs + 1, dtype=torch.int32, device=self.device
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables.view(-1),
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
            self.cuda_graphs[bs]["state"] = state
        else:
            state = None

        torch.cuda.synchronize()
        # Run once outside to warmup
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths=input_lengths,
            state=state,
        ):
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths_,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )

            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                # Seqlen is rebuilt inside the capture. This rebinds the local
                # name; self.cuda_graphs[bs]["input_lengths"] keeps the raw tensor.
                # NOTE(review): presumably so any tensors Seqlen derives from
                # input_lengths are recorded in the graph and recomputed on
                # replay — confirm.
                input_lengths = Seqlen(input_lengths=input_lengths)
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=None,
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    input_lengths=input_lengths,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
                )
                # Graph output buffers; `forward` slices them to the real batch
                # size after each replay.
                self.cuda_graphs[bs]["logits"] = logits
                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

1157
    def warmup(self, batch: FlashCausalLMBatch):
        """Profile memory with a warmup batch and size the KV cache from it.

        Steps:
        1. Allocate a KV cache of ``batch.num_blocks`` blocks and run one
           ``generate_token`` step on `batch`.
        2. Measure free device memory and compute how many cache blocks fit.
        3. Re-allocate the KV cache with that many blocks.
        4. On ROCm, run PyTorch TunableOp warmup unless
           ``PYTORCH_TUNABLEOP_ENABLED`` is set to something other than ``"1"``.
        5. If ``CUDA_GRAPHS`` is set, capture a decode graph for each listed
           batch size that is at least ``speculate + 1`` (all sizes when
           ``self.speculate`` is None). Out-of-memory failures there are
           logged, not raised.

        Args:
            batch: The warmup batch (per the comment below, the biggest batch
                we could ever receive).

        Returns:
            Number of token slots in the final KV cache,
            ``num_blocks * BLOCK_SIZE``.

        Raises:
            RuntimeError: If the warmup step runs out of device memory.
        """
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE
            # Record the number of blocks allocated above *before* running the
            # step. `generate_token` may return `None` (all requests finished)
            # or a different batch, but these blocks are still allocated when
            # free memory is measured below.
            batch_num_blocks = batch.num_blocks

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Bytes per block across all layers, for both key and value.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )

        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            if (
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    # For seqlen = 1, we dispatch to LLMM1 kernel.
                    tuning_sequences = [2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                log_master(
                    logger.info,
                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                )

                if os.path.isfile(tunableop_filepath):
                    log_master(
                        logger.info,
                        f"The file {tunableop_filepath} already exists and will be reused.",
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                    # Persist results after every length so partial progress survives.
                    torch.cuda.tunable.write_file(tunableop_filepath)
                torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                )

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )

        return int(num_blocks * BLOCK_SIZE)
1279

fxmarty's avatar
fxmarty committed
1280
1281
1282
1283
1284
    def tunableop_warmup(self, seqlen: int):
        """Run one dummy prefill forward over `seqlen` tokens.

        ``warmup`` calls this once per TunableOp tuning length on ROCm. The
        model output is discarded.
        """
        dev = self.device
        # One prefill "sequence" covering all `seqlen` tokens.
        prefill_bounds = torch.tensor([0, seqlen], device=dev, dtype=torch.int32)
        # Dummy value, some models (starcoder2) don't accept `None`.
        lengths_meta = Seqlen(
            input_lengths=torch.ones(seqlen, dtype=torch.int32, device=dev)
        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=torch.zeros(seqlen, dtype=torch.int64, device=dev),
            position_ids=torch.zeros(seqlen, dtype=torch.int32, device=dev),
            cu_seqlen_prefill=prefill_bounds,
            kv_cache=self.kv_cache,
            block_tables=None,
            input_lengths=lengths_meta,
            slots=torch.arange(seqlen, dtype=torch.int64, device=dev),
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

1305
    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on `batch` and return ``(logits, speculative_logits)``.

        When the batch carries speculative ids, each sequence is expanded to
        ``1 + speculative_length`` query rows: its current token followed by
        its speculated tokens.

        Prefill steps run eagerly, as do decode steps for which no captured
        CUDA graph is large enough. Other decode steps copy their inputs into
        the smallest captured graph whose batch size is >= the number of rows,
        replay it, and slice the outputs back to the real size.
        ``speculative_logits`` may be ``None``.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            # Row i becomes [current token, speculated tokens...], flattened.
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            # Per-row offsets 0..speculative_length; `arange_int` is the int32
            # copy added to slots and input lengths.
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Repeat each sequence's block table once per expanded row.
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        if cu_seqlen_prefill is None:
            max_past = self.max_past()
            if max_past is not None:
                # In decode, not prefill, we're actually overwriting the KV-cache
                # in a circular buffer mode.
                # This makes sure the max_s for the decode pass is correct.
                max_s = min(max_past, max_s)

        bs = input_ids.shape[0]
        # Smallest captured graph whose (padded) batch size can hold `bs` rows.
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = None if padded_bs is None else self.cuda_graphs[padded_bs]

        if cu_seqlen_prefill is not None or cuda_graph is None:
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths=input_lengths,
            ):
                input_lengths = Seqlen(input_lengths=input_lengths)
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    input_lengths=input_lengths,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    adapter_data=adapter_data,
                )
                # The prefill cache indices are only used once.
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                return logits, speculative_logits

        # NOTE(review): `adapter_data` is not applied on this path. The graph
        # captured in `cuda_graph_warmup` calls the model without adapter data.
        # Confirm that adapters and CUDA graphs are never combined.

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padding rows get slot -1 and input length 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        state = cuda_graph.get("state")
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths=input_lengths,
            state=state,
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
1427
1428
1429
1430

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one model step on ``batch`` and turn the logits into generations.

        The step calls ``self.forward``, selects the next token(s) of every request
        with ``batch.next_token_chooser`` (several tokens may be accepted per
        request when speculative ids are present), updates ``batch`` in place so it
        is ready for the next decode step, and builds one ``Generation`` for every
        request handled by this shard (``i % self.world_size == self.rank``).

        Args:
            batch: the batch to advance. It is mutated in place.

        Returns:
            ``(generations, next_batch, (forward_ns, decode_ns))``.

            - ``next_batch`` is ``None`` when every request in the batch stopped,
              otherwise it is ``batch`` itself.
            - ``forward_ns`` is the time from entry until the GPU -> CPU syncs
              completed.
            - ``decode_ns`` is the time spent afterwards building the
              generations.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            # Repeat each request's adapter index `new_length` times so there is
            # one index per token fed for that request (assumes every request
            # contributes exactly `speculative_length + 1` tokens this step).
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            # When prefill logprobs are requested `out` holds logits for every
            # prompt position; only the rows used to pick the next token are kept.
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
                len(batch)
            )
        else:
            next_token_logits = out
            next_adapter_indices = batch.adapter_meta.adapter_indices

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            # `cu_seqlen_prefill[1:] - 1` presumably points at the last prompt
            # token of each request (cumulative offsets starting at 0).
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # First pass: tensor bookkeeping for each member of the batch.
        # `index` walks the flat `next_input_ids`, which holds the accepted
        # tokens of every request back to back.
        index = 0
        for i, (input_length, n_accepted_ids) in enumerate(
            zip(batch.input_lengths, accepted_ids)
        ):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Initialize adapter indices
                # In decode, we only have one token per row in the batch, so grab last index
                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
                    end_index - 1
                ]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values.
        # `accepted_ids.cumsum(-1) - 1` is the flat index of each request's last
        # accepted token, which becomes that request's next input id.
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids
        batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # Second pass (CPU side): detokenize, check stopping criteria and build
        # the generations for each member of the batch.
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens discarded because a stop happened earlier
            left = 0

            if n_accepted_ids > 1:
                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                # Use the flag bound for *this* request rather than the
                # loop-leaked `stop` variable.
                if current_stopped:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    # One Tokens entry per accepted position of this request.
                    all_top_tokens = []
                    for step_top_token_ids, step_top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            step_top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in step_top_token_ids
                        ]
                        all_top_tokens.append(
                            Tokens(
                                step_top_token_ids,
                                step_top_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                        )
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only metadata is not needed for the following decode steps
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

    def _forward_context(
        self,
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        input_lengths: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
        """Return the context manager to wrap around a model forward call.

        For any attention backend other than ``"flashinfer"`` this is a no-op
        ``nullcontext()``. With flashinfer, the prefill state is set up when
        ``cu_seqlen_prefill`` is given, and the decode state otherwise.

        Args:
            block_tables: block tables of the batch. They are flattened with
                ``view(-1)`` before being handed to the decode state.
            cu_seqlen_prefill: cumulative sequence lengths for a prefill step, or
                ``None`` for a decode step.
            input_lengths: per-request input lengths. Required (non-``None``)
                for a decode step.
            state: flashinfer state to use instead of ``self.prefill_state`` /
                ``self.decode_state``, e.g. the state stored alongside a
                captured CUDA graph.

        Returns:
            A context manager. It is the result of ``use_prefill_state`` /
            ``use_decode_state`` under flashinfer, or ``nullcontext()``
            otherwise.
        """
        if ATTENTION != "flashinfer":
            return nullcontext()

        # Imported lazily so the flashinfer module is only loaded when that
        # backend is actually selected.
        from text_generation_server.layers.attention.flash_infer import (
            use_decode_state,
            use_prefill_state,
        )

        if cu_seqlen_prefill is not None:
            return use_prefill_state(
                state=state if state is not None else self.prefill_state,
                cu_seqlens=cu_seqlen_prefill,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
            )
        else:
            assert input_lengths is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
                input_lengths=input_lengths,
                block_tables=block_tables.view(-1),
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
            )