flash_causal_lm.py 65.1 KB
Newer Older
1
import math
2
import os
3
import time
4
import itertools
5
6
7
import torch
import torch.distributed

8
9
import numpy as np

10
from loguru import logger
11
12
from dataclasses import dataclass
from opentelemetry import trace
13
14
15
16
17
18
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
Daniël de Kok's avatar
Daniël de Kok committed
19
from typing import Iterable, Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
20

drbh's avatar
drbh committed
21
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
fxmarty's avatar
fxmarty committed
22
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Daniël de Kok's avatar
Daniël de Kok committed
23
from text_generation_server.utils.chunks import concat_text_chunks
Nicolas Patry's avatar
Nicolas Patry committed
24
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
25
from text_generation_server.models import Model
26
from text_generation_server.utils.tokens import batch_top_tokens
27
from text_generation_server.utils.dist import RANK
Nicolas Patry's avatar
Nicolas Patry committed
28
from text_generation_server.utils.speculate import get_speculate
29
30
31
32
33
34
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    hub,
)
35
36
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
37
    Tokens,
38
39
40
41
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
Nicolas Patry's avatar
Nicolas Patry committed
42
43
from text_generation_server.models.globals import (
    MEM_POOL,
44
45
    FLASH_DECODING,
    BLOCK_SIZE,
Nicolas Patry's avatar
Nicolas Patry committed
46
47
48
49
    CUDA_GRAPHS,
    get_adapter_to_index,
    MODEL_ID,
)
50
from text_generation_server.layers.attention import Seqlen
51
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
52
from text_generation_server.utils.dist import MEMORY_FRACTION
drbh's avatar
drbh committed
53
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
54

Nicolas Patry's avatar
Nicolas Patry committed
55
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
56
57
58
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
59
60
)

Nicolas Patry's avatar
Nicolas Patry committed
61
62
tracer = trace.get_tracer(__name__)

63
64
65
66
67
68
69
70
71
72
73
74
75
76

# Will be set in init
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: int):
    """Store ``sliding_window`` in the module-level ``SLIDING_WINDOW``."""
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    """Return the module-level ``SLIDING_WINDOW`` value.

    Returns:
        The value last stored in ``SLIDING_WINDOW``; this is ``None`` until
        something assigns it, so callers must handle ``None``.
    """
    # Reading a module global needs no ``global`` statement.
    return SLIDING_WINDOW

77

78
79
80
81
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state carried between prefill and decode steps.

    In the field comments below, ``b`` is the number of requests in the batch
    and ``s_i`` the length of sequence ``i``.
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    # ``torch.Tensor`` is the type; ``torch.tensor`` is only the factory function
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    adapter_meta: AdapterBatchMetadata

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int
143

144
145
    def to_pb(self) -> generate_pb2.CachedBatch:
        """Summarise this batch as a ``generate_pb2.CachedBatch`` message.

        The message carries the batch id, the ids of the held requests, the
        number of requests, and ``num_blocks * BLOCK_SIZE`` as ``max_tokens``.
        """
        held_request_ids = [request.id for request in self.requests]
        token_budget = self.num_blocks * BLOCK_SIZE
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=held_request_ids,
            size=len(self),
            max_tokens=token_budget,
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        """Tokenize the text of every request in a single tokenizer call.

        Each request's input chunks are joined with ``concat_text_chunks``.
        The whole batch is truncated to the largest ``truncate`` value found
        among the requests.

        Returns:
            The ``"input_ids"`` entry of the tokenizer output.
        """
        texts = []
        longest_truncate = 0
        # Single pass: ``requests`` is only typed as an Iterable
        for request in requests:
            texts.append(concat_text_chunks(request.input_chunks.chunks))
            if request.truncate > longest_truncate:
                longest_truncate = request.truncate

        encoded = tokenizer(texts, truncation=True, max_length=longest_truncate)
        return encoded["input_ids"]
166

drbh's avatar
drbh committed
167
168
169
170
171
172
173
174
175
    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from requests whose inputs are already tokenized.

        Args:
            pb: Protobuf batch; ``pb.id`` and ``pb.requests`` are read.
            tokenizer: Used for its ``bos_token_id`` and forwarded to
                ``StoppingCriteria.from_pb`` and
                ``HeterogeneousNextTokenChooser.from_pb``.
            batch_tokenized_inputs: One token-id sequence per request, in the
                same order as ``pb.requests``.
            dtype: Forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: Device the batch tensors are created on or moved to
                (``start_slots`` is created without a device argument).

        Returns:
            A new batch with the prefill metadata populated and
            ``speculative_ids`` set to ``None``.
        """
        sliding_window = get_sliding_windows()
        position_ids = []
        cu_seqlen_prefill = [0]
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        adapter_indices_list = []
        adapter_set = set()

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        block_tables = []
        slots = []

        # Loop-invariant lookups, resolved once for the whole batch
        adapter_to_index = get_adapter_to_index()
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length check keeps an
            # input made of a single BOS token from raising IndexError on
            # index 1.
            if (
                tokenized_input[0] == tokenizer.bos_token_id
                and len(tokenized_input) > 1
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Unknown adapter ids fall back to index 0
            adapter_index = adapter_to_index.get(r.adapter_id, 0)
            adapter_indices_list.append(torch.full((input_length,), adapter_index))
            adapter_set.add(adapter_index)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            slots.extend(request_slots[:total_tokens])
            num_blocks += len(request_blocks)
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Keep every prompt position of this request
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Keep only the last prompt position of this request
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        adapter_indices = torch.cat(adapter_indices_list).to(
            dtype=torch.int64, device=device
        )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for row, request_input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[row, : len(request_input_ids)] = request_input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # ``torch.cat`` already yields a fresh tensor; ``.to`` converts it
            # without the copy-construct warning of ``torch.tensor(tensor)``.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        # Filled row by row on CPU, then moved to the device in one copy
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for row, request_blocks in enumerate(block_tables):
            block_tables_tensor[row, : len(request_blocks)] = torch.tensor(
                request_blocks
            )
        block_tables_tensor = block_tables_tensor.to(device)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
            speculative_ids=None,
        )

444
445
446
447
448
449
450
451
452
453
454
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize the requests of ``pb`` and build a batch from the result.

        Runs ``batch_tokenized_inputs`` on ``pb.requests`` and passes the
        result to ``from_tokenized`` together with the remaining arguments.
        """
        return cls.from_tokenized(
            pb,
            tokenizer,
            cls.batch_tokenized_inputs(pb.requests, tokenizer),
            dtype,
            device,
        )

455
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch restricted to ``request_ids``, in the given order.

        ``self`` is returned unchanged when ``request_ids`` has as many
        entries as the batch. Otherwise a new batch of the same type is
        built, with its prefill-only fields set to ``None``.

        Raises:
            ValueError: If ``request_ids`` is empty.
        """
        kept_count = len(request_ids)
        if kept_count == 0:
            raise ValueError("Batch must have at least one request")
        # Equal sizes are taken to mean that the requests are the same
        if kept_count == len(self):
            return self

        device = self.input_ids.device

        # Row of every kept request in the current batch (used to index
        # tensors), and the row it will occupy in the filtered batch
        indices = [self.requests_idx_mapping[request_id] for request_id in request_ids]
        requests_idx_mapping = {
            request_id: new_row for new_row, request_id in enumerate(request_ids)
        }

        # Per-request Python-side state, gathered in the new order
        requests = [self.requests[row] for row in indices]
        all_input_ids = [self.all_input_ids[row] for row in indices]
        input_lengths = [self.input_lengths[row] for row in indices]
        prefix_offsets = [self.prefix_offsets[row] for row in indices]
        read_offsets = [self.read_offsets[row] for row in indices]
        stopping_criterias = [self.stopping_criterias[row] for row in indices]
        top_n_tokens = [self.top_n_tokens[row] for row in indices]
        block_tables = [self.block_tables[row] for row in indices]

        # Boolean mask over the current slots marking the ones to keep
        keep_slot_mask = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )
        # Filled on CPU so that it is moved to the device only once
        slot_indices = torch.empty(kept_count, dtype=torch.int64)

        start_slots = []
        adapter_set = set()
        num_blocks = 0
        max_blocks = 0
        max_seqlen = 0
        # Running start offset of each request in the filtered slots tensor
        slot_cursor = 0

        for new_row, (old_row, request, length, criteria, table) in enumerate(
            zip(indices, requests, input_lengths, stopping_criterias, block_tables)
        ):
            adapter_to_index = get_adapter_to_index()
            adapter_set.add(adapter_to_index.get(request.adapter_id, 0))

            remaining_tokens = criteria.max_new_tokens - criteria.current_tokens
            # Number of slots kept for this request
            slot_span = length + remaining_tokens - 1

            # Mark the request's slot range in the current layout
            old_start = self.start_slots[old_row]
            keep_slot_mask[old_start : old_start + slot_span] = True

            # Record where the request starts, and which slot is current, in
            # the new layout
            start_slots.append(slot_cursor)
            slot_indices[new_row] = slot_cursor + length - 1
            slot_cursor += slot_span

            num_blocks += len(table)
            max_blocks = max(max_blocks, len(table))
            max_seqlen = max(max_seqlen, length)

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        adapter_indices = self.adapter_meta.adapter_indices[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[keep_slot_mask]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        if self.speculative_ids is not None:
            speculative_ids = self.speculative_ids[indices]
        else:
            speculative_ids = None

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to the device now that the whole tensor is filled
        slot_indices = slot_indices.to(device)

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )
        adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into one.

        Tensors are allocated once at their final size using the first
        batch's tensors as templates, then filled batch by batch. The merged
        batch takes the first batch's ``batch_id`` and has its prefill-only
        fields set to ``None``.

        Args:
            batches: Batches to merge, in order; the first one provides the
                templates for allocation and the next-token-chooser settings.

        Returns:
            A new batch holding the requests of all ``batches``.
        """
        # Batch attributes
        requests = []
        # Always a fresh dict: adopting the first batch's mapping would let the
        # updates below write other batches' entries into that input batch.
        requests_idx_mapping = {}

        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        first = batches[0]
        input_ids = first.input_ids.new_empty(total_batch_size)
        position_ids = first.position_ids.new_empty(total_batch_size)
        slots = first.slots.new_empty(total_slots)
        slot_indices = first.slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = first.input_lengths_tensor.new_empty(total_batch_size)
        block_tables_tensor = first.block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = first.all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = first.top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )
        total_indices_size = sum(
            b.adapter_meta.adapter_indices.shape[0] for b in batches
        )
        adapter_indices = first.adapter_meta.adapter_indices.new_empty(
            total_indices_size
        )
        adapter_set = set()
        adapter_segment_builder = SegmentConcatBuilder()

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset the mapping of each batch by the cumulative batch size
            # (zero for the first batch)
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            # Copy over adapter indices
            adapter_start_index = cumulative_adapter_indices_size
            adapter_end_index = (
                cumulative_adapter_indices_size
                + batch.adapter_meta.adapter_indices.shape[0]
            )
            adapter_indices[adapter_start_index:adapter_end_index] = (
                batch.adapter_meta.adapter_indices
            )
            cumulative_adapter_indices_size = adapter_end_index
            adapter_set.update(batch.adapter_meta.adapter_set)
            adapter_segment_builder.concat(
                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
            )

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend(r.parameters for r in batch.requests)
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=first.next_token_chooser.dtype,
            device=first.next_token_chooser.device,
            tokenizer=first.next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        # Only the first batch decides whether speculative ids are carried over
        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if first.speculative_ids is not None
            else None
        )

        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

        return cls(
            batch_id=first.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
        )

    def __len__(self):
        """Return the number of requests held by the batch."""
        return len(self.requests)


813
814
815
816
817
818
819
820
821
822
823
824
# Names of attention and MLP projection layers.
# NOTE(review): presumably the layers that adapters may target — confirm at the use site.
ADAPTER_LAYERS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
# Layer names treated as row-parallel; note that "lm_head" is not in ADAPTER_LAYERS.
# NOTE(review): presumably selects how adapter weights are sharded — confirm at the use site.
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


825
826
827
class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        # NOTE: this argument is neither read nor mutated in this constructor,
        # so the shared default list is harmless here.
        lora_adapter_ids: Optional[list] = [],
        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
        # Any class exposing `from_pretrained(model_id, revision=..., trust_remote_code=...)`.
        config_class: type = AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads=None,
        skip_special_tokens: bool = True,
    ):
        """Load tokenizer, config and weights, build the model and register it.

        Args:
            model_id: Identifier passed to every ``from_pretrained`` call and
                to ``weight_files``.
            model_class: Callable invoked as ``model_class(prefix, config, weights)``
                to build the model.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on ``config.quantize``; for "awq", "exl2", "gptq"
                and "marlin" the GPTQ parameters are additionally set on the
                weights.
            speculator: Stored on ``config.speculator``.
            dtype: Explicit dtype. When ``None``, ``default_dtype`` is used on
                CUDA/XPU and ``torch.bfloat16`` on IPEX CPU.
            trust_remote_code: Forwarded to the ``from_pretrained`` calls.
            lora_adapter_ids: Accepted for interface compatibility; not read here.
            tokenizer_class: Class whose ``from_pretrained`` builds the tokenizer.
            config_class: Class whose ``from_pretrained`` builds the config.
            default_dtype: Fallback dtype on CUDA/XPU when ``dtype`` is ``None``.
            aliases: Forwarded to ``Weights``.
            num_kv_heads: Overrides the number of key/value heads read from the
                config.
            skip_special_tokens: Accepted for interface compatibility; not read
                here.

        Raises:
            NotImplementedError: If CUDA is unavailable and ``SYSTEM`` is not
                "ipex".
            ValueError: If the number of key/value heads cannot be determined.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception as e:
            # Best effort: loading continues without a generation config, but
            # leave a trace instead of swallowing the error silently.
            logger.debug(
                "Could not load a generation config for {}: {}", model_id, e
            )

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device, dtype, process_group=self.process_group, aliases=aliases
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            # Make sure the attribute exists so it can be read unconditionally below.
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        # Heads are divided across the process group, except when there is a
        # single key/value head, which is kept whole.
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            else num_kv_heads
        )
        assert self.num_kv_heads > 0
        self.head_size = config.hidden_size // config.num_attention_heads

        # Filled by `cuda_graph_warmup` (keyed by batch size) and `init_kv_cache`.
        self.cuda_graphs = {}
        self.kv_cache = []

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """Batch class associated with this model: ``FlashCausalLMBatch``."""
        return FlashCausalLMBatch

941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
    def max_past(self) -> Optional[int]:
        """Return the wrapped model's ``max_past`` attribute, or ``None`` if it has none."""
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """(Re)allocate ``self.kv_cache`` as one ``(key, value)`` tensor pair per layer.

        The tensor layout depends on the backend:

        * ``FLASH_DECODING``: key and value are both
          ``(num_blocks, BLOCK_SIZE, num_heads, head_size)``.
        * IPEX on CPU: key and value are both
          ``(num_blocks, num_heads, BLOCK_SIZE, head_size)``.
        * otherwise: key is ``(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)``
          and value is ``(num_blocks, num_heads, head_size, BLOCK_SIZE)``, where
          ``x`` is 1 on IPEX/XPU and ``BLOCK_SIZE // element_size(dtype)``
          elsewhere.

        Args:
            num_blocks: Number of cache blocks to allocate per layer.
            num_layers: Number of ``(key, value)`` pairs to create.
            num_heads: Number of key/value heads in each tensor.
            head_size: Size of one attention head.
            dtype: Dtype of the cache tensors.
            device: Device the cache tensors are allocated on.
        """
        # Drop the references to the previous cache before calling
        # `empty_cache()`, so the old tensors are no longer held by this object.
        self.kv_cache = []
        empty_cache()

        if FLASH_DECODING:
            key_shape = value_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            key_shape = value_shape = (num_blocks, num_heads, BLOCK_SIZE, head_size)
        else:
            # `x` splits the head dimension of the key tensor; it is only
            # needed for this layout.
            if SYSTEM == "ipex" and device.type == "xpu":
                x = 1
            else:
                element_size = torch.tensor([], dtype=dtype).element_size()
                x = BLOCK_SIZE // element_size
            key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
            value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)

        self.kv_cache = [
            (
                torch.empty(key_shape, dtype=dtype, device=device),
                torch.empty(value_shape, dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]
1010

1011
1012
1013
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a CUDA graph of one decode forward pass for batch size ``bs``.

        Static input tensors are allocated, the model is run once eagerly, and
        the same call is then captured into a ``torch.cuda.CUDAGraph``. The
        static inputs, the graph and the captured outputs are stored in
        ``self.cuda_graphs[bs]`` so that ``forward`` can copy fresh inputs into
        them and replay the graph.

        If anything fails along the way, the partially filled
        ``self.cuda_graphs[bs]`` entry is removed before the exception is
        re-raised: ``forward`` selects graphs purely by key, so a leftover
        entry without a captured graph/logits would later be picked up.

        Args:
            bs: Batch size (number of rows in the static inputs).
            max_s: Value every entry of the static ``input_lengths`` is set to;
                also forwarded to the model as ``max_s``.
            max_bt: Number of columns of the static block table.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )

        graph_state = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        self.cuda_graphs[bs] = graph_state

        try:
            eager_seqlen = Seqlen(input_lengths=input_lengths)
            graph = torch.cuda.CUDAGraph()
            graph_state["graph"] = graph

            torch.cuda.synchronize()
            # Run once outside to warmup
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=eager_seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
                # The Seqlen wrapper is deliberately rebuilt inside the capture
                # context, from the same static tensor.
                # NOTE(review): presumably so that any tensor work done while
                # constructing it is recorded in the graph — confirm against Seqlen.
                captured_seqlen = Seqlen(input_lengths=input_lengths)
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=None,
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    input_lengths=captured_seqlen,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
                )
                graph_state["logits"] = logits
                graph_state["speculative_logits"] = speculative_logits
            torch.cuda.synchronize()
        except Exception:
            # Never leave a half-initialised graph registered.
            self.cuda_graphs.pop(bs, None)
            raise

1068
    def warmup(self, batch: FlashCausalLMBatch):
        """Run a warmup pass, size the KV cache from free memory and warm up kernels.

        Steps:

        1. Allocate a KV cache of ``batch.num_blocks`` blocks and run
           ``generate_token`` once on ``batch``.
        2. Compute how many blocks fit into the memory reported by
           ``get_free_memory`` (keeping 5% spare) and re-allocate the KV cache
           with that many blocks plus the blocks of the warmup batch.
        3. On ROCm, optionally run PyTorch TunableOp tuning, controlled by the
           ``PYTORCH_TUNABLEOP_ENABLED``, ``PYTORCH_TUNABLEOP_TUNING`` and
           ``PYTORCH_TUNABLEOP_SEQLENS`` environment variables.
        4. If ``CUDA_GRAPHS`` is set, capture one CUDA graph per listed batch
           size. An out-of-memory error during capture is logged, not raised.

        Args:
            batch: Warmup batch; treated as the biggest batch that can be
                received.

        Returns:
            ``num_blocks * BLOCK_SIZE`` for the final KV cache, as an ``int``.

        Raises:
            RuntimeError: If the first warmup pass runs out of CUDA memory.
        """
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                # Keep tuning off for this first pass; it is re-enabled below.
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Bytes of one block across all layers; the factor 2 covers key and value.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
        # `generate_token` may hand back `None` instead of a batch.
        batch_num_blocks = batch.num_blocks if batch is not None else 0

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )

        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            tunableop_enabled = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
            if tunableop_enabled is None or tunableop_enabled == "1":
                torch.cuda.tunable.enable()

                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    # For seqlen = 1, we dispatch to LLMM1 kernel.
                    tuning_sequences = [2, 3, 4, 5, 6, 7]

                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                logger.info(
                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
                )

                if os.path.isfile(tunableop_filepath):
                    logger.info(
                        f"The file {tunableop_filepath} already exists and will be reused."
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                    # Written after every sequence length, so results gathered
                    # so far are on disk even if a later one fails.
                    torch.cuda.tunable.write_file(tunableop_filepath)
                torch.cuda.tunable.tuning_enable(False)
            else:
                logger.info(
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
                )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    # With speculation, skip batch sizes smaller than speculate + 1.
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        return int(num_blocks * BLOCK_SIZE)
1183

fxmarty's avatar
fxmarty committed
1184
1185
1186
1187
1188
    def tunableop_warmup(self, seqlen: int):
        """Run one dummy prefill forward pass over ``seqlen`` tokens.

        Called by ``warmup`` once per sequence length while TunableOp tuning
        is enabled. All inputs are placeholder tensors on ``self.device`` and
        the model output is discarded.

        Args:
            seqlen: Number of tokens in the dummy sequence.
        """
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        dummy_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        input_lengths = Seqlen(input_lengths=dummy_lengths)

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=torch.tensor(
                [0, seqlen], device=self.device, dtype=torch.int32
            ),
            kv_cache=self.kv_cache,
            block_tables=None,
            input_lengths=input_lengths,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

1209
    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on ``batch``, replaying a captured CUDA graph when possible.

        When ``batch.speculative_ids`` is set, every request is expanded into
        ``speculative_length + 1`` rows: input ids are interleaved with the
        speculative ids, and position ids, slots and input lengths are offset
        by ``0..speculative_length``; the block table row is repeated.

        The eager path is taken for prefill (``cu_seqlen_prefill`` is set) or
        when no captured graph has a batch size ``>=`` the current one.
        Otherwise inputs are copied into the static tensors of the smallest
        suitable graph, the graph is replayed and the outputs are sliced back
        to the real batch size. ``adapter_data`` is only passed to the model
        on the eager path.

        Args:
            batch: Batch to run.
            adapter_data: Adapter data forwarded to the model (eager path only).

        Returns:
            ``(logits, speculative_logits)``; the second item may be ``None``.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1

            # Row offsets 0..speculative_length, in the default integer dtype
            # and as int32 (used for slots and input lengths).
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)

            input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

        max_past = self.max_past()
        if cu_seqlen_prefill is None and max_past is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(max_past, max_s)

        bs = input_ids.shape[0]
        # Smallest captured batch size that can hold this batch, if any.
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            input_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                adapter_data=adapter_data,
            )
            # The cache indices are cleared on the batch once they have been
            # passed to the model.
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padding rows get slot -1 and length 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()
        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
1318
1319
1320
1321

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
1322
1323
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
1324
        prefill = batch.cu_seqlen_prefill is not None
1325
        prefill_logprobs = batch.prefill_next_token_indices is not None
1326

drbh's avatar
drbh committed
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)
1355

1356
1357
        if prefill:
            next_token_logits = (
1358
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
1359
            )
Nicolas Patry's avatar
Nicolas Patry committed
1360
1361
            if speculative_logits is not None:
                speculative_logits = (
OlivierDehaene's avatar
OlivierDehaene committed
1362
1363
1364
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
Nicolas Patry's avatar
Nicolas Patry committed
1365
                )
drbh's avatar
drbh committed
1366
1367
1368
1369
            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
                len(batch)
            )

1370
1371
        else:
            next_token_logits = out
drbh's avatar
drbh committed
1372
            next_adapter_indices = batch.adapter_meta.adapter_indices
1373

Nicolas Patry's avatar
Nicolas Patry committed
1374
        speculate = get_speculate()
OlivierDehaene's avatar
OlivierDehaene committed
1375
1376
1377
1378
1379
1380
1381
1382
1383
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
Nicolas Patry's avatar
Nicolas Patry committed
1384
            speculate,
OlivierDehaene's avatar
OlivierDehaene committed
1385
1386
            batch.speculative_ids,
            speculative_logits,
1387
1388
        )

Nicolas Patry's avatar
Nicolas Patry committed
1389
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
Nicolas Patry's avatar
Nicolas Patry committed
1390
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
Nicolas Patry's avatar
Nicolas Patry committed
1391
1392
        )

1393
        if prefill:
1394
            if len(batch) > 1 and prefill_logprobs:
1395
1396
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
1397
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
1398
1399

            next_position_ids = batch.position_ids.new_empty(len(batch))
1400
1401
1402
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
1403
1404
1405
1406
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

1407
1408
1409
1410
1411
        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
1412
        stopped = True
1413
1414

        # Zipped iterator
OlivierDehaene's avatar
OlivierDehaene committed
1415
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
1416

1417
1418
1419
1420
        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

1421
        # For each member of the batch
Nicolas Patry's avatar
Nicolas Patry committed
1422
        index = 0
OlivierDehaene's avatar
OlivierDehaene committed
1423
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
1424
            # Indexing metadata
1425
1426
1427
            start_index = cumulative_length
            end_index = cumulative_length + input_length

1428
            if prefill:
1429
1430
1431
1432
1433
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

1434
1435
1436
1437
                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

drbh's avatar
drbh committed
1438
1439
1440
1441
1442
1443
                # Initialize adapter indices
                # In decode, we only have one token per row in the batch, so grab last index
                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
                    end_index - 1
                ]

1444
1445
                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
1446
1447
                if prefill_logprobs:
                    if len(batch) > 1:
drbh's avatar
drbh committed
1448
1449
1450
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
1451
1452
1453
1454
1455
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]
1456

Nicolas Patry's avatar
Nicolas Patry committed
1457
1458
1459
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1
1460
1461
1462

            cumulative_length += input_length

drbh's avatar
drbh committed
1463
        # Update values
Nicolas Patry's avatar
Nicolas Patry committed
1464
1465
1466
1467
1468
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids
drbh's avatar
drbh committed
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
        batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )
1479

1480
        if prefill and prefill_logprobs:
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
Nicolas Patry's avatar
Nicolas Patry committed
1491
        next_token_ids = next_input_ids.tolist()
1492
1493
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()
1494
1495
1496
1497
1498

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
1499
1500
            batch.prefix_offsets,
            batch.read_offsets,
1501
1502
            batch.stopping_criterias,
            batch.all_input_ids,
1503
1504
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
Nicolas Patry's avatar
Nicolas Patry committed
1505
            batch.top_n_tokens,
Nicolas Patry's avatar
Nicolas Patry committed
1506
            accepted_ids,
Nicolas Patry's avatar
Nicolas Patry committed
1507
1508
            batch_top_token_ids,
            batch_top_token_logprobs,
1509
1510
1511
        )

        # For each member of the batch
Nicolas Patry's avatar
Nicolas Patry committed
1512
        index = 0
1513
1514
1515
        for i, (
            request,
            input_length,
1516
1517
            prefix_offset,
            read_offset,
1518
1519
            stopping_criteria,
            all_input_ids,
1520
1521
            do_sample,
            seed,
Nicolas Patry's avatar
Nicolas Patry committed
1522
            top_n_tokens,
Nicolas Patry's avatar
Nicolas Patry committed
1523
            n_accepted_ids,
Nicolas Patry's avatar
Nicolas Patry committed
1524
1525
            top_token_ids,
            top_token_logprobs,
1526
        ) in enumerate(iterator):
1527
            # Append next token to all tokens
Nicolas Patry's avatar
Nicolas Patry committed
1528
1529
1530
            next_token_texts = []
            left = 0

1531
1532
1533
1534
            if n_accepted_ids > 1:
                if RANK == 0:
                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")

Nicolas Patry's avatar
Nicolas Patry committed
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)
1546

Nicolas Patry's avatar
Nicolas Patry committed
1547
1548
1549
1550
                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )
1551

Nicolas Patry's avatar
Nicolas Patry committed
1552
1553
1554
1555
1556
1557
1558
                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped
1559

OlivierDehaene's avatar
OlivierDehaene committed
1560
1561
1562
1563
            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
Nicolas Patry's avatar
Nicolas Patry committed
1564
            index += n_accepted_ids
1565

1566
1567
1568
1569
1570
            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
1571
1572
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1573
1574
1575
1576
1577
1578
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
1579
1580
                    )
                    generated_text = GeneratedText(
1581
1582
1583
1584
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
1585
1586
1587
1588
1589
                    )
                else:
                    generated_text = None

                # Prefill
1590
1591
1592
1593
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

1594
1595
                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
1596
                        out_start_index : out_end_index - 1
1597
1598
1599
1600
1601
1602
1603
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
Nicolas Patry's avatar
Nicolas Patry committed
1604
1605

                    prefill_tokens = Tokens(
OlivierDehaene's avatar
OlivierDehaene committed
1606
1607
1608
1609
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
1610
1611
1612
1613
                    )
                else:
                    prefill_tokens = None

Nicolas Patry's avatar
Nicolas Patry committed
1614
                if top_n_tokens > 0:
Nicolas Patry's avatar
Nicolas Patry committed
1615
                    all_top_tokens = []
drbh's avatar
drbh committed
1616
                    for top_token_ids, top_token_logprobs in zip(
1617
1618
                        top_token_ids, top_token_logprobs
                    ):
Nicolas Patry's avatar
Nicolas Patry committed
1619
1620
1621
1622
1623
1624
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
1625
1626
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
Nicolas Patry's avatar
Nicolas Patry committed
1627
1628
1629
1630
1631
1632
1633
1634
1635
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
1636
1637
1638
                else:
                    top_tokens = None

1639
1640
1641
                generation = Generation(
                    request.id,
                    prefill_tokens,
Nicolas Patry's avatar
Nicolas Patry committed
1642
1643
1644
1645
1646
1647
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
1648
                    generated_text,
Nicolas Patry's avatar
Nicolas Patry committed
1649
                    top_tokens,
1650
1651
                )

1652
                generations.append(generation)
1653

drbh's avatar
drbh committed
1654
1655
1656
            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
OlivierDehaene's avatar
OlivierDehaene committed
1657
1658
1659
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )
drbh's avatar
drbh committed
1660

1661
            # Update values
1662
            batch.input_lengths[i] = input_length + n_accepted_ids
Nicolas Patry's avatar
Nicolas Patry committed
1663
1664
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
1665
1666
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
1667
1668
            batch.all_input_ids[i] = all_input_ids

1669
1670
        if stopped:
            # No need to return a batch if we know that all requests stopped
1671
1672
1673
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)
1674

1675
1676
1677
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
1678

1679
1680
1681
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750

    @property
    def supports_adapter_loading(self) -> bool:
        """Whether adapter loading is supported; this implementation always answers ``True``."""
        adapters_supported = True
        return adapters_supported

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        """Build the lookup from adapter targets to weight names and layer objects.

        Returns:
            A dict whose keys are ``(layer_index, layer_type)`` tuples (note that
            the annotation above says ``str``) and whose values are
            ``(weight_name, layer_object)`` pairs. ``q_proj``/``k_proj``/``v_proj``
            all point at the layer's fused ``query_key_value`` object, and
            ``gate_proj``/``up_proj`` both point at the fused ``gate_up_proj``
            object. The final entry is ``(0, "lm_head")``.
        """
        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        wrapper = self.model
        for nested_attr in ("language_model", "text_model"):
            if hasattr(wrapper, nested_attr):
                text_backbone = getattr(wrapper, nested_attr)
                break
        else:
            text_backbone = wrapper

        root = "model.layers"
        targets = {}

        for idx, block in enumerate(text_backbone.model.layers):
            attention = block.self_attn
            attn_root = f"{root}.{idx}.self_attn"

            # The three input projections share one fused module.
            fused_qkv = attention.query_key_value
            for proj in ("q_proj", "k_proj", "v_proj"):
                targets[(idx, proj)] = (f"{attn_root}.{proj}", fused_qkv)
            targets[(idx, "o_proj")] = (f"{attn_root}.o_proj", attention.o_proj)

            # TODO: this is a hack to avoid the gate_proj for
            # FlashStarcoder2 that doesnt have these layers
            if not (hasattr(block, "mlp") and hasattr(block.mlp, "gate_up_proj")):
                continue

            feed_forward = block.mlp
            mlp_root = f"{root}.{idx}.mlp"

            # gate and up projections share one fused module as well.
            fused_gate_up = feed_forward.gate_up_proj
            for proj in ("gate_proj", "up_proj"):
                targets[(idx, proj)] = (f"{mlp_root}.{proj}", fused_gate_up)
            targets[(idx, "down_proj")] = (
                f"{mlp_root}.down_proj",
                feed_forward.down_proj,
            )

        targets[(0, "lm_head")] = ("lm_head", text_backbone.lm_head)
        return targets

    @property
    def adapter_layers(self) -> List[str]:
        """Return the module-level ``ADAPTER_LAYERS`` list (the same object, not a copy)."""
        layer_types = ADAPTER_LAYERS
        return layer_types

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        """Return the default traced layer types, ``q_proj`` and ``v_proj``.

        A fresh list is built on every access.
        """
        traced_by_default = ["q_proj", "v_proj"]
        return traced_by_default

    def get_num_layers_for_type(self, layer_type: str) -> int:
        """Return how many layers exist for ``layer_type``.

        Args:
            layer_type: Adapter layer type, e.g. ``"q_proj"`` or ``"lm_head"``.

        Returns:
            ``1`` for ``"lm_head"``; for every other type, the length of the
            language model's ``model.layers``.
        """
        if layer_type == "lm_head":
            return 1

        # Resolve the language model the same way `adapter_target_to_layer`
        # does: VLMs (e.g. LlavaNext, Idefics2) nest it under `language_model`
        # or `text_model`, so `self.model.model` does not exist for them.
        if hasattr(self.model, "language_model"):
            _model = self.model.language_model
        elif hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

        return len(_model.model.layers)

    def is_row_parallel(self, layer_type: str) -> bool:
        """Tell whether ``layer_type`` is a member of the module-level ``ROW_PARALLEL``."""
        row_parallel = layer_type in ROW_PARALLEL
        return row_parallel