"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "80e2c4a8de3ad34af12f6127956975b69c1beaa7"
flash_causal_lm.py 74.4 KB
Newer Older
1
from contextlib import nullcontext
2
import math
3
import os
4
import time
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
14
15
16
17
18
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
19
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
20

drbh's avatar
drbh committed
21
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
fxmarty's avatar
fxmarty committed
22
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Daniël de Kok's avatar
Daniël de Kok committed
23
from text_generation_server.utils.chunks import concat_text_chunks
Nicolas Patry's avatar
Nicolas Patry committed
24
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
25
from text_generation_server.models import Model
26
from text_generation_server.utils.log import log_master
27
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
28
from text_generation_server.utils.speculate import get_speculate
29
30
31
32
33
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
34
35
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
36
    Tokens,
37
38
39
40
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
Nicolas Patry's avatar
Nicolas Patry committed
41
42
from text_generation_server.models.globals import (
    MEM_POOL,
43
    ATTENTION,
44
    BLOCK_SIZE,
Nicolas Patry's avatar
Nicolas Patry committed
45
    CUDA_GRAPHS,
46
    TGI_WIGGLE_ROOM,
Nicolas Patry's avatar
Nicolas Patry committed
47
48
    get_adapter_to_index,
)
49
from text_generation_server.layers.attention import Seqlen
50
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
51
from text_generation_server.utils.dist import MEMORY_FRACTION
52
from text_generation_server.utils.quantization import get_loader
drbh's avatar
drbh committed
53
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
54

Nicolas Patry's avatar
Nicolas Patry committed
55
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
56
57
58
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
59
60
)

Nicolas Patry's avatar
Nicolas Patry committed
61
62
# OpenTelemetry tracer for this module; batch operations below (e.g. ``filter``)
# are wrapped in spans via ``@tracer.start_as_current_span(...)``.
tracer = trace.get_tracer(__name__)

63
64
65
66
67
68
69
70
71
72
73
74
75
76

# Module-wide sliding-window size in tokens (compared against per-request input
# lengths when building prefill batches). ``None`` means no sliding window.
# Will be set in init via ``set_sliding_window``.
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: Optional[int]):
    """Set the module-wide sliding window size.

    Args:
        sliding_window: window size in tokens, or ``None`` to disable it.
    """
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    """Return the module-wide sliding window size, or ``None`` if none is set."""
    # Reading a module global needs no ``global`` declaration.
    return SLIDING_WINDOW


def init_cpu_threads_env(rank_id: int, world_size: int):
    """Bind this rank's torch CPU threads and memory to a NUMA node, if possible.

    Does nothing unless the optional ``numa`` package is importable. Ranks are
    grouped onto NUMA nodes in contiguous runs of ``ceil(world_size / nodes)``,
    and each rank is given a slice of its node's CPUs.

    Args:
        rank_id: index of this process within the distributed group.
        world_size: total number of processes in the group.
    """
    import importlib.util

    if importlib.util.find_spec("numa") is None:
        return

    import numa
    import psutil

    nodes = numa.get_max_node() + 1
    rank_per_node = math.ceil(world_size / nodes)
    # psutil.cpu_count(logical=False) returns None when the physical core
    # count cannot be determined; fall back to the logical count (then 1)
    # rather than failing with a TypeError in the division below.
    physical_cpus = (
        psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 1
    )
    num_cpus_per_nodes = int(physical_cpus / nodes)
    node_id = int(rank_id / rank_per_node)
    rank_offset_per_node = rank_id % rank_per_node

    # An explicit OMP_NUM_THREADS wins over the computed per-rank share.
    omp_num_threads = os.getenv("OMP_NUM_THREADS")
    if omp_num_threads is None:
        num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
    else:
        num_cpus_per_rank = int(omp_num_threads)

    # Only rebind memory when the current membind spans every node
    # (presumably: no binding was configured externally).
    if len(numa.get_membind()) == nodes:
        numa.set_membind([node_id])
    torch.set_num_threads(num_cpus_per_rank)
    # Likewise, only pin CPUs when the current affinity covers every logical
    # CPU (presumably: the process was not already pinned externally).
    if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True):
        cpu_start = num_cpus_per_rank * rank_offset_per_node
        numa.set_affinity(
            0,
            list(numa.node_to_cpus(node_id))[
                cpu_start : cpu_start + num_cpus_per_rank
            ],
        )
    logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}")


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b + 1 (leading 0) containing the cumulative sequence
    # lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor

    # size [b], number of leading prompt *tokens* per request that belong to
    # the cached prefix (these tokens are not fed through the model again)
    prefix_lens: List[int]
    prefix_lens_tensor: torch.Tensor

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # Prefixes: per request, the token ids of the cached prefix
    prefix_ids: List[List[int]]

    # All tokens: per request, the token ids excluding the cached prefix
    all_input_ids: List[List[int]]
    # [b, max_length] zero-padded tensor version of all_input_ids
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    adapter_meta: AdapterBatchMetadata

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarize this batch as a ``CachedBatch`` protobuf message.

        ``max_tokens`` reports the token capacity of all blocks the batch
        holds (``num_blocks * BLOCK_SIZE``).
        """
        capacity = self.num_blocks * BLOCK_SIZE
        ids = [request.id for request in self.requests]
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=ids,
            size=len(self),
            max_tokens=capacity,
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        """Tokenize the input of every request.

        Each request's input chunks are joined with ``concat_text_chunks`` and
        tokenized with truncation to ``r.truncate`` tokens, honouring the
        request's ``add_special_tokens`` flag.

        Returns:
            One list of token ids per request, in request order.
        """
        return [
            tokenizer(
                concat_text_chunks(r.input_chunks.chunks),
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            for r in requests
        ]

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from already-tokenized requests.

        Args:
            pb: protobuf batch; ``pb.requests`` is zipped with
                ``batch_tokenized_inputs``.
            tokenizer: used to build stopping criteria and the next-token
                chooser.
            batch_tokenized_inputs: one list of token ids per request (as
                returned by ``batch_tokenized_inputs``).
            dtype: dtype passed to the next-token chooser.
            device: device on which the batch tensors are created.

        Returns:
            A new batch whose ``cu_seqlen_prefill`` is set (prefill phase).
        """
        sliding_window = get_sliding_windows()
        # Loop invariants, looked up once per batch instead of per request.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length
        ADAPTER_TO_INDEX = get_adapter_to_index()

        position_ids = []
        cu_seqlen_prefill = [0]
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        prefix_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        adapter_indices_list = []
        adapter_set = set()

        # Cumulative length (tokens actually fed to the model, prefix excluded)
        cumulative_length = 0
        cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        block_tables = []
        slots = []
        prefix_lens = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            orig_input_length = len(tokenized_input)

            prefix_len = r.prefix_len
            assert (
                prefix_len <= orig_input_length
            ), f"Prefix {prefix_len} vs input {orig_input_length}"
            if prefix_len == orig_input_length:
                # Whole prompt is cached: keep the last token so at least one
                # token still goes through the model.
                assert prefix_len > 0
                prefix_len -= 1

            prefix_ids.append(tokenized_input[:prefix_len])
            tokenized_input = tokenized_input[prefix_len:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids continue after the cached prefix
            request_position_ids = torch.arange(
                prefix_len, orig_input_length, dtype=torch.int32
            )
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
            adapter_indices_list.append(torch.full((input_length,), adapter_index))
            adapter_set.add(adapter_index)

            # Paged attention
            # Remove one as the first token does not have a past

            # Tokens that need to be mapped to blocks.
            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length

            # Tokens that need to be mapped to slots. We don't need slots for the
            # cached prefix (if present).
            slot_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                # Skip the slots that hold the cached prefix.
                request_slots = r.slots[prefix_len:]

            block_tables.append(request_blocks)
            slots.extend(request_slots)
            prefix_lens.append(prefix_len)
            num_blocks += len(request_blocks)
            start_slots.append(cumulative_slot_tokens)

            request_slot_indices = torch.arange(
                cumulative_slot_tokens,
                cumulative_slot_tokens + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Indices into the flattened input_ids, which only contain the
                # non-prefix tokens. Position ids start at prefix_len, so they
                # must not be used here (they would be shifted by prefix_len
                # whenever prefix caching is active).
                prefill_head_indices.append(
                    torch.arange(
                        cumulative_length,
                        cumulative_length + input_length,
                        dtype=torch.int32,
                    )
                )
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Only the last token's hidden state is needed.
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_slot_tokens += slot_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        adapter_indices = torch.cat(adapter_indices_list).to(
            dtype=torch.int64, device=device
        )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # ``.to`` converts the existing tensor instead of copy-constructing
            # it through ``torch.tensor`` (which warns for tensor inputs).
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)

        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)
        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            prefix_lens=prefix_lens,
            prefix_lens_tensor=prefix_lens_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
            speculative_ids=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize the requests of ``pb`` and build a prefill batch from them.

        ``pb`` must contain at least one request (checked with an assert).
        """
        assert len(pb.requests) > 0
        tokenized = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        batch = cls.from_tokenized(pb, tokenizer, tokenized, dtype, device)
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch that keeps only ``request_ids`` (decode phase).

        The returned batch has no prefill metadata (``cu_seqlen_prefill`` and
        the ``prefill_*`` fields are ``None``). If every request is kept, the
        batch itself is returned unchanged.

        Args:
            request_ids: ids of the requests to keep, in their new order.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # Slots to keep after filtering. Filled on CPU (one slice write per
        # request) and moved to the device once, instead of issuing one device
        # write per request.
        slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []
        prefix_ids = []

        input_lengths = []
        prefix_lens = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()
        # Loop invariant: look the adapter mapping up once.
        ADAPTER_TO_INDEX = get_adapter_to_index()

        num_blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            prefix_len = self.prefix_lens[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])
            prefix_ids.append(self.prefix_ids[idx])

            input_lengths.append(request_input_length)
            prefix_lens.append(prefix_len)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice: keep this request's slots from the old slots tensor
            slot_start = self.start_slots[idx]
            slot_filtering_indices[
                slot_start : slot_start + request_input_length + remaining_tokens - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        adapter_indices = self.adapter_meta.adapter_indices[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices.to(device)]
        prefix_lens_tensor = self.prefix_lens_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_lens=prefix_lens,
            prefix_lens_tensor=prefix_lens_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into a single ``FlashCausalLMBatch``.

        Per-request Python lists are concatenated in batch order, and the
        per-request tensors are copied into freshly allocated tensors sized for
        the combined batch. ``slot_indices`` and ``start_slots`` are shifted by
        the number of slots of the preceding batches, so they keep pointing at
        the same entries of the concatenated ``slots`` tensor.

        The result is built without prefill state (``cu_seqlen_prefill``,
        ``prefill_cache_indices`` and the ``prefill_*`` fields are ``None``).
        It reuses the ``batch_id`` of ``batches[0]``. The ``requests_idx_mapping``
        dicts of the input batches are left untouched.

        Args:
            batches: batches to merge. Must be non-empty, because ``batches[0]``
                serves as the template for tensor dtypes/devices and the
                next-token-chooser settings.

        Returns:
            The concatenated batch.
        """
        # Batch attributes
        requests = []
        # Always a fresh dict: aliasing batches[0].requests_idx_mapping and then
        # writing offsets into it would corrupt that input batch, and leave it
        # corrupted if concatenation failed part-way through.
        requests_idx_mapping = {}

        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            # Widest row needed in all_input_ids_tensor: current input length
            # plus the tokens still to be generated plus room for speculation.
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        total_indices_size = sum(
            b.adapter_meta.adapter_indices.shape[0] for b in batches
        )
        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
            total_indices_size
        )
        adapter_set = set()
        adapter_segment_builder = SegmentConcatBuilder()

        start_slots = []
        block_tables = []
        prefix_lens = []
        all_input_ids = []
        prefix_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's request indices by the number of requests that
            # precede it in the merged batch (0 for the first batch).
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            # slot_indices index into the concatenated `slots`, so shift them.
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            # Copy over adapter indices
            adapter_start_index = cumulative_adapter_indices_size
            adapter_end_index = (
                cumulative_adapter_indices_size
                + batch.adapter_meta.adapter_indices.shape[0]
            )
            adapter_indices[adapter_start_index:adapter_end_index] = (
                batch.adapter_meta.adapter_indices
            )
            cumulative_adapter_indices_size = adapter_end_index
            adapter_set.update(batch.adapter_meta.adapter_set)
            adapter_segment_builder.concat(
                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
            )

            # Rows are left-aligned; columns beyond this batch's width stay zero.
            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            prefix_lens.extend(batch.prefix_lens)
            all_input_ids.extend(batch.all_input_ids)
            prefix_ids.extend(batch.prefix_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        # Expected invariant (not checked here):
        # sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        # Assumes either every batch has speculative ids or none does.
        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            prefix_lens=prefix_lens,
            prefix_lens_tensor=prefix_lens_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
        )

    def __len__(self):
        """Return how many requests this batch holds."""
        pending_requests = self.requests
        return len(pending_requests)


902
903
904
905
906
907
908
909
910
911
912
913
# Projection layer names listed as adapter targets, in a fixed order:
# q, k, v, o, gate, up, down.
ADAPTER_LAYERS = [
    f"{layer}_proj" for layer in ("q", "k", "v", "o", "gate", "up", "down")
]
# Layers treated as row-parallel (presumably sharded along their input
# dimension under tensor parallelism — confirm against the adapter code).
ROW_PARALLEL = {
    "o_proj",
    "down_proj",
    "lm_head",
}


914
915
916
class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = None,
        tokenizer_class: Any = AutoTokenizer,
        config_class: Any = AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        """Load tokenizer, config and weights for ``model_id`` and build the model.

        Args:
            model_id: model repository id or local path.
            model_class: model constructor, called as
                ``model_class(prefix, config, weights)``.
            revision: optional revision of ``model_id``.
            quantize: quantization method. Stored on ``self`` and on the config,
                and passed to ``get_loader``.
            speculator: stored on the config as ``config.speculator``.
            dtype: weight dtype. When ``None``, ``default_dtype`` is used on
                CUDA/XPU and ``torch.bfloat16`` on IPEX CPU.
            trust_remote_code: forwarded to every ``from_pretrained`` call.
            lora_adapter_ids: accepted for interface compatibility and not used
                by this constructor. Defaults to ``None`` rather than a shared
                mutable ``[]``.
            tokenizer_class: class whose ``from_pretrained`` loads the tokenizer.
            config_class: class whose ``from_pretrained`` loads the config.
            default_dtype: dtype used when ``dtype`` is ``None`` (except on IPEX CPU).
            aliases: forwarded to ``Weights``.
            num_kv_heads: overrides the key/value head count read from the config.
            head_size: overrides the head size derived from the config.
            skip_special_tokens: accepted but not used by this constructor.

        Raises:
            NotImplementedError: if neither CUDA nor IPEX (XPU or CPU) is available.
            ValueError: if the number of key/value heads cannot be determined.
        """
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
                init_cpu_threads_env(rank_id=rank, world_size=world_size)
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            # Best effort: a generation config is optional, so any failure to
            # load it just keeps the tokenizer's own EOS handling.
            pass

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        # Attention heads handled by this rank.
        self.num_heads = config.num_attention_heads // self.process_group.size()
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        # A single KV head is not split across ranks.
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0

        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj, that allows for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class consumed by this model: ``FlashCausalLMBatch``."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
    def max_past(self) -> Optional[int]:
        """Return the wrapped model's ``max_past`` attribute.

        Returns:
            ``self.model.max_past`` if the model defines it, otherwise ``None``.
            The return type is therefore ``Optional[int]``, not ``int``.
        """
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """(Re)allocate ``self.kv_cache`` as one ``(key, value)`` pair per layer.

        References to any previous cache are dropped and ``empty_cache()`` is
        called before the new tensors are allocated. The tensors are created
        with ``torch.empty`` and are therefore uninitialized.

        Layout per backend:

        * ``ATTENTION`` in {"flashdecoding", "flashinfer"}: key and value are both
          ``(num_blocks, BLOCK_SIZE, num_heads, head_size)``.
        * IPEX on a CPU device: key and value are both
          ``(num_blocks, num_heads, BLOCK_SIZE, head_size)``.
        * otherwise: key is ``(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)``
          and value is ``(num_blocks, num_heads, head_size, BLOCK_SIZE)``. Here
          ``x`` is 1 on IPEX XPU and ``BLOCK_SIZE // element_size`` elsewhere.

        Args:
            num_blocks: number of cache blocks per layer.
            num_layers: number of ``(key, value)`` pairs to allocate.
            num_heads: key/value heads per rank.
            head_size: size of each head.
            dtype: cache dtype.
            device: device to allocate on.
        """
        self.kv_cache = []
        empty_cache()

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            key_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
            value_shape = key_shape
        elif SYSTEM == "ipex" and device.type == "cpu":
            # Compare the device *type* (as the XPU check below does) so that
            # an indexed CPU device such as "cpu:0" also gets the IPEX layout.
            key_shape = (num_blocks, num_heads, BLOCK_SIZE, head_size)
            value_shape = key_shape
        else:
            if SYSTEM == "ipex" and device.type == "xpu":
                x = 1
            else:
                element_size = torch.tensor([], dtype=dtype).element_size()
                # NOTE(review): paged-attention kernels usually pack 16 bytes per
                # vector (x = 16 // element_size); this equals the expression
                # below only while BLOCK_SIZE == 16 in this branch — confirm.
                x = BLOCK_SIZE // element_size
            key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
            value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)

        self.kv_cache = [
            (
                torch.empty(key_shape, dtype=dtype, device=device),
                torch.empty(value_shape, dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]
1135

1136
1137
1138
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for batch size ``bs``.

        Static dummy inputs are allocated for ``bs`` sequences of length
        ``max_s``, each using the block table ``0..max_bt-1``. They are
        registered in ``self.cuda_graphs[bs]`` together with the graph object.
        Inside ``self._forward_context`` the model is then run once eagerly as a
        warmup, and once more under ``torch.cuda.graph`` capture. The captured
        ``logits`` and ``speculative_logits`` are stored in the same entry.
        """
        dev = self.device

        # Static decode inputs the graph is captured against.
        input_ids = torch.zeros(bs, dtype=torch.int64, device=dev)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=dev)
        slots = torch.arange(bs, dtype=torch.int64, device=dev)
        input_lengths = [max_s] * bs
        prefix_lengths = [0] * bs
        input_lengths_tensor = torch.full(
            (bs,), max_s, dtype=torch.int32, device=dev
        )
        prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=dev)
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=dev)
            .repeat(bs)
            .reshape((bs, max_bt))
        )

        state = None
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=input_lengths,
                prefix_lens=prefix_lengths,
            )
            from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )

            # Kept as locals so they stay alive for the whole capture.
            block_tables_ptr = torch.zeros(bs + 1, dtype=torch.int32, device=dev)
            last_page_len = torch.ones(bs, dtype=torch.int32, device=dev)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths_tensor,
            prefix_lengths=prefix_lengths_tensor,
            state=state,
            graph=graph,
        )

        # Arguments shared by the eager warmup run and the captured run.
        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            slots=slots,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )

        def build_seqlen():
            # Constructed anew for each forward, at the call site, so the
            # captured run builds its Seqlen inside the capture region.
            return Seqlen(
                input_lengths=input_lengths_tensor,
                prefix_lengths=prefix_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )

        torch.cuda.synchronize()
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths_tensor=input_lengths_tensor,
            state=state,
            prefix_lens_tensor=prefix_lengths_tensor,
        ):
            # Eager warmup run outside of the graph.
            self.model.forward(seqlen=build_seqlen(), **forward_kwargs)

            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                logits, speculative_logits = self.model.forward(
                    seqlen=build_seqlen(), **forward_kwargs
                )
                graph_entry = self.cuda_graphs[bs]
                graph_entry["logits"] = logits
                graph_entry["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

1245
    def warmup(self, batch: FlashCausalLMBatch):
        """Profile memory with the warmup batch and size the KV cache.

        Steps:

        1. Allocate a KV cache of ``batch.num_blocks`` blocks and run one
           ``generate_token`` step on ``batch``. The warmup batch is the biggest
           batch we could ever receive.
        2. Compute how many blocks fit in the free memory and re-allocate the
           cache with that many blocks.
        3. On ROCm, optionally run PyTorch TunableOp GEMM tuning.
        4. If ``CUDA_GRAPHS`` is set, capture decode CUDA graphs for those sizes.

        Returns:
            The number of KV-cache token slots, ``num_blocks * BLOCK_SIZE``.

        Raises:
            RuntimeError: if the warmup step runs out of GPU memory.
        """
        empty_cache()

        # Blocks allocated for the warmup cache. Captured up front because
        # `generate_token` rebinds `batch`, possibly to None if every request
        # finished. The cache stays allocated either way and must be counted
        # in the memory estimate below.
        warmup_num_blocks = batch.num_blocks
        try:
            self.init_kv_cache(
                warmup_num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE

            # Don't tune GEMMs on the prefill warmup itself.
            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Bytes per block summed over all layers, for both key and value.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Only use a TGI_WIGGLE_ROOM fraction of the free memory, leaving headroom.
            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
            # Add the warmup cache's blocks: they are still allocated, so they
            # are already part of the peak memory and not of free_memory.
            + warmup_num_blocks
        )

        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            tunableop_env = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
            # TunableOp is on by default: unset or "1" enables it.
            if tunableop_env is None or tunableop_env == "1":
                torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    # For seqlen = 1, we dispatch to LLMM1 kernel.
                    tuning_sequences = [2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                log_master(
                    logger.info,
                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
                )

                if os.path.isfile(tunableop_filepath):
                    log_master(
                        logger.info,
                        f"The file {tunableop_filepath} already exists and will be reused.",
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                    # Persist after every length so partial results are kept.
                    torch.cuda.tunable.write_file(tunableop_filepath)
                torch.cuda.tunable.tuning_enable(False)
            else:
                log_master(
                    logger.info,
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
                )

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    # With speculation, each request needs speculate + 1 slots.
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )

        return int(num_blocks * BLOCK_SIZE)
1367

fxmarty's avatar
fxmarty committed
1368
1369
1370
1371
1372
    def tunableop_warmup(self, seqlen: int):
        """Run one dummy prefill forward over ``seqlen`` tokens.

        Used during ROCm TunableOp warmup so the GEMMs for this sequence length
        are exercised (and tuned, when tuning is enabled). The output is
        discarded.

        Args:
            seqlen: number of tokens in the dummy input. Also passed to the
                model as the integer ``max_s``.
        """
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        cu_seqlen_prefill = torch.tensor(
            [0, seqlen], device=self.device, dtype=torch.int32
        )
        # Separate name so `seqlen` stays the integer length. Rebinding it to
        # the Seqlen object would pass that object as `max_s` below.
        seqlen_info = Seqlen(
            input_lengths=input_lengths,
            prefix_lengths=prefix_lens_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
            max_q=1,
            max_k=seqlen,
        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=None,
            seqlen=seqlen_info,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

1401
    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch`` and return ``(logits, speculative_logits)``.

        Prefill passes, and decode passes for which no captured CUDA graph is
        large enough, run the model eagerly. Other decode passes copy their
        inputs into the smallest captured CUDA graph whose batch size is at
        least the number of input rows, then replay that graph.

        When ``batch.speculative_ids`` is set, every request is expanded into
        ``speculative_length + 1`` rows (the current token followed by the
        speculated ones), so all of them are scored in one pass.

        Args:
            batch: batch to run. On the eager path
                ``batch.prefill_cache_indices`` is cleared after the call.
            adapter_data: adapter data passed to the model on the eager path
                (not used when a CUDA graph is replayed).

        Returns:
            ``logits`` and ``speculative_logits`` (the latter may be ``None``).
            On the CUDA-graph path both are the first ``bs`` rows of the
            graph's static output buffers.
        """
        # Model Forward
        # Inputs shared by the plain and the speculative paths.
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        prefix_lens_tensor = batch.prefix_lens_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            speculative_ids = batch.speculative_ids
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1

            # Each request becomes [current token, speculated tokens...],
            # flattened to B * new_length rows.
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            prefix_lens_tensor = (
                prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Smallest captured graph that can hold `bs` rows (inputs get padded).
        padded_bs = min((k for k in self.cuda_graphs.keys() if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                # NOTE(review): with speculation, `block_tables` has
                # B * new_length rows, while batch.input_lengths presumably has
                # one entry per request -- confirm the ragged conversion is
                # correct in that case.
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    prefix_lens=batch.prefix_lens,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                prefix_lens_tensor=prefix_lens_tensor,
            ):
                max_k = (input_lengths + prefix_lens_tensor).max().item()
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    prefix_lengths=prefix_lens_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=max_s,
                    max_k=max_k,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    adapter_data=adapter_data,
                )
                # Clear the indices so later passes on this batch don't reuse them.
                batch.prefill_cache_indices = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                prefix_lens=batch.prefix_lens,
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["prefix_lengths"].zero_()
        cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            prefix_lens_tensor=cuda_graph["prefix_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
1555
1556
1557
1558

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        """Run one generation step for ``batch`` and collect per-request results.

        Steps, in order:
          1. Build adapter metadata (expanded per speculative token when
             ``batch.speculative_ids`` is set) and call ``self.forward``.
          2. Pick the next token(s) with ``batch.next_token_chooser``. With
             speculation, several tokens per request may be accepted
             (``accepted_ids``).
          3. First loop, with no GPU->CPU sync: write accepted ids into
             ``batch.all_input_ids_tensor``. On prefill, also fill the next
             position ids, the adapter indices and the prefill-logprob
             gather indices.
          4. Sync results to the CPU. A second loop then detokenizes, applies
             stopping criteria and builds ``Generation`` objects. Only requests
             with ``i % self.world_size == self.rank`` get a ``Generation``.

        ``batch`` is updated in place (input ids, positions, lengths, slot
        indices, adapter metadata, offsets, grammar state).

        Returns:
            ``(generations, next_batch, (forward_ns, decode_ns))``.
            ``next_batch`` is ``None`` when every request stopped.
            ``forward_ns`` spans from entry until just after the CPU sync.
            ``decode_ns`` spans the second loop.
        """
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        # A bool here. It is later rebound to None (decode) or, on prefill with
        # logprobs, to a list of floats.
        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            # Repeat each request's adapter index for every expanded row, to
            # match the B * new_length rows built in forward().
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            # With prefill logprobs, `out` has rows for prompt tokens too; keep
            # only the rows used to pick the next token.
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
            # Filled per request in the first loop below.
            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
                len(batch)
            )

        else:
            next_token_logits = out
            next_adapter_indices = batch.adapter_meta.adapter_indices

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            # cu_seqlen_prefill[1:] - 1 is the last prompt row of each request;
            # keep only those slot indices for decoding.
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        # Stays True only if every request hits a stopping criterion below.
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        # `index` walks next_input_ids, which holds the accepted tokens of all
        # requests back to back.
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Initialize adapter indices
                # In decode, we only have one token per row in the batch, so grab last index
                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
                    end_index - 1
                ]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        # The last accepted token of each request becomes its next input id.
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids
        batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        # Rebound from a tensor to a Python list for the CPU-side loop below.
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.prefix_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            prefix_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            # Number of accepted tokens after the stopping token; they are
            # dropped from this request's output.
            left = 0

            if n_accepted_ids > 1:
                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                # NOTE(review): `stop`/`reason` come from the last iteration of
                # the loop above and are only bound if n_accepted_ids >= 1
                # (presumably always true) -- confirm.
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = (
                        [float("nan")] * (len(prefix_ids) + 1)
                    ) + prefill_logprobs[out_start_index : out_end_index - 1]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefix_ids + prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefix_ids + prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    # NOTE(review): the loop targets shadow the outer
                    # `top_token_ids`/`top_token_logprobs`. This works because
                    # zip() captures the outer values first, but it is fragile.
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        # Prefill-only metadata is cleared before the batch is returned for decoding.
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)
1920
1921
1922
1923
1924
1925

    def _forward_context(
        self,
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        input_lengths_tensor: torch.Tensor,
        prefix_lens_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
        """Return the context manager wrapping one model forward pass.

        Only the FlashInfer backend needs per-pass attention state; every other
        backend gets a no-op ``nullcontext()``.

        For FlashInfer, a prefill pass (``cu_seqlen_prefill`` given) uses
        ``use_prefill_with_paged_kv_state``; otherwise ``use_decode_state`` is
        used. Both receive per-sequence total lengths
        (``input_lengths_tensor + prefix_lens_tensor``). ``state`` overrides the
        default state object (``self.prefill_with_paged_kv_state`` or
        ``self.decode_state``).
        """
        if ATTENTION != "flashinfer":
            return nullcontext()

        from text_generation_server.layers.attention.flashinfer import (
            use_decode_state,
            use_prefill_with_paged_kv_state,
        )

        is_prefill = cu_seqlen_prefill is not None
        if is_prefill:
            prefill_state = (
                self.prefill_with_paged_kv_state if state is None else state
            )
            return use_prefill_with_paged_kv_state(
                state=prefill_state,
                block_tables=block_tables,
                cu_seqlens=cu_seqlen_prefill,
                input_lengths=input_lengths_tensor + prefix_lens_tensor,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
            )

        assert input_lengths_tensor is not None
        decode_state = self.decode_state if state is None else state
        return use_decode_state(
            state=decode_state,
            input_lengths=input_lengths_tensor + prefix_lens_tensor,
            block_tables=block_tables,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            page_size=BLOCK_SIZE,
        )
Nicolas Patry's avatar
Nicolas Patry committed
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988


def block_tables_to_ragged(
    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
    """Flatten per-sequence block tables into one ragged tensor for FlashInfer.

    Sequence ``i`` contributes the first ``prefix_lens[i] + input_lengths[i]``
    entries of ``block_tables[i]``. The contributions are concatenated in
    sequence order.

    Args:
        block_tables: per-sequence tables, indexed as ``block_tables[i][:n]``.
        input_lengths: per-sequence input lengths.
        prefix_lens: per-sequence prefix lengths (same length as
            ``input_lengths``).

    Returns:
        A 1-D ``int32`` tensor on ``block_tables.device`` whose length is the
        sum of all prefix and input lengths.
    """
    assert len(input_lengths) == len(prefix_lens)

    # Total (prefix + input) length of every sequence, in batch order.
    seq_lens = [
        prefix + length for length, prefix in zip(input_lengths, prefix_lens)
    ]

    ragged = torch.empty(
        sum(seq_lens), dtype=torch.int32, device=block_tables.device
    )

    start = 0
    for row, seq_len in enumerate(seq_lens):
        end = start + seq_len
        ragged[start:end] = block_tables[row][:seq_len]
        start = end

    return ragged