flash_causal_lm.py 63.3 KB
Newer Older
1
import math
2
import os
3
import time
4
5
6
import torch
import torch.distributed

7
8
import numpy as np

9
from loguru import logger
10
11
from dataclasses import dataclass
from opentelemetry import trace
12
13
14
15
16
17
from transformers import (
    PreTrainedTokenizerBase,
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
)
Daniël de Kok's avatar
Daniël de Kok committed
18
from typing import Iterable, Optional, Tuple, List, Type, Dict
fxmarty's avatar
fxmarty committed
19

drbh's avatar
drbh committed
20
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
fxmarty's avatar
fxmarty committed
21
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
Daniël de Kok's avatar
Daniël de Kok committed
22
from text_generation_server.utils.chunks import concat_text_chunks
Nicolas Patry's avatar
Nicolas Patry committed
23
from text_generation_server.utils.import_utils import SYSTEM
OlivierDehaene's avatar
OlivierDehaene committed
24
from text_generation_server.models import Model
25
from text_generation_server.utils.log import log_master
26
from text_generation_server.utils.tokens import batch_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
27
from text_generation_server.utils.speculate import get_speculate
28
29
30
31
32
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
33
34
from text_generation_server.models.types import (
    Batch,
Nicolas Patry's avatar
Nicolas Patry committed
35
    Tokens,
36
37
38
39
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
Nicolas Patry's avatar
Nicolas Patry committed
40
41
from text_generation_server.models.globals import (
    MEM_POOL,
42
43
    FLASH_DECODING,
    BLOCK_SIZE,
Nicolas Patry's avatar
Nicolas Patry committed
44
45
46
47
    CUDA_GRAPHS,
    get_adapter_to_index,
    MODEL_ID,
)
48
from text_generation_server.layers.attention import Seqlen
49
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
50
from text_generation_server.utils.dist import MEMORY_FRACTION
51
from text_generation_server.utils.quantization import get_loader
drbh's avatar
drbh committed
52
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
53

Nicolas Patry's avatar
Nicolas Patry committed
54
from text_generation_server.utils.import_utils import (
Nicolas Patry's avatar
Nicolas Patry committed
55
56
57
    empty_cache,
    synchronize,
    get_free_memory,
Nicolas Patry's avatar
Nicolas Patry committed
58
59
)

Nicolas Patry's avatar
Nicolas Patry committed
60
61
# OpenTelemetry tracer for this module; FlashCausalLMBatch.filter and
# FlashCausalLMBatch.concatenate are wrapped in spans created from it.
tracer = trace.get_tracer(__name__)

62
63
64
65
66
67
68
69
70
71
72
73
74
75

# Module-wide sliding-window size. When not None, the batch builder only keeps the
# last SLIDING_WINDOW prefill positions of each sequence for the KV cache.
# ``None`` means no sliding window. Will be set in init.
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: Optional[int]):
    """Set the module-wide sliding-window size (``None`` disables it)."""
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    """Return the module-wide sliding-window size, or ``None`` if none is set."""
    # Reading a module global needs no ``global`` declaration.
    return SLIDING_WINDOW

76

77
78
79
80
@dataclass
class FlashCausalLMBatch(Batch):
    """Batch state for flash/paged-attention causal LM generation.

    Built for prefill by ``from_pb`` / ``from_tokenized``; ``filter`` and
    ``concatenate`` produce decode-phase batches (prefill metadata set to None).
    """

    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch
    # (starts at 0), only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences
    slots: torch.Tensor

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Adapter metadata for each request
    adapter_meta: AdapterBatchMetadata

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        """Describe this batch as a ``CachedBatch`` protobuf.

        ``max_tokens`` is reported as ``num_blocks * BLOCK_SIZE``.
        """
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        """Tokenize every request's text in a single tokenizer call.

        Each request's input chunks are joined with ``concat_text_chunks``.
        Truncation uses the largest ``truncate`` of any request. Per-request
        truncation is applied afterwards in ``from_tokenized``.

        Returns:
            The tokenizer's ``"input_ids"`` output, one entry per request.
        """
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Build a prefill batch from already-tokenized request inputs.

        Args:
            pb: protobuf batch; ``pb.requests`` is zipped with
                ``batch_tokenized_inputs``.
            tokenizer: used for BOS de-duplication, stopping criteria and the
                next-token chooser.
            batch_tokenized_inputs: one sequence of token ids per request.
            dtype: forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: device on which the batch tensors are created.
        """
        sliding_window = get_sliding_windows()
        # Request-independent lookups, done once instead of per request.
        speculative_length = get_speculate()
        speculative_length = 0 if speculative_length is None else speculative_length
        ADAPTER_TO_INDEX = get_adapter_to_index()

        position_ids = []
        cu_seqlen_prefill = [0]
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        adapter_indices_list = []
        adapter_set = set()

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        block_tables = []
        slots = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Drop a duplicated leading BOS token. The length check keeps
            # single-token (or empty) inputs from raising IndexError.
            if (
                len(tokenized_input) > 1
                and tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            # NOTE(review): may be negative for inputs shorter than 5 tokens;
            # presumably tolerated by the detokenization code — confirm.
            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
            adapter_indices_list.append(torch.full((input_length,), adapter_index))
            adapter_set.add(adapter_index)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            slots.extend(request_slots[:total_tokens])
            num_blocks += len(request_blocks)
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Keep an output row for every prompt position.
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Only the last prompt position is needed to sample the next token.
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        adapter_indices = torch.cat(adapter_indices_list).to(
            dtype=torch.int64, device=device
        )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # The per-request pieces are already tensors: concatenate and move
            # them instead of re-wrapping with torch.tensor (which copies and warns).
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        # Build the padded block table on CPU, then move it to the device once.
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
            speculative_ids=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        """Tokenize ``pb``'s requests and build a prefill batch from them."""
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        """Return a batch containing only ``request_ids``, in that order.

        Returns ``self`` unchanged when ``len(request_ids) == len(self)``.
        Slots are compacted to each kept request's remaining token budget,
        and prefill metadata is set to None on the result.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []
        adapter_set = set()

        num_blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        # Request-independent lookup, done once instead of per request.
        ADAPTER_TO_INDEX = get_adapter_to_index()

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU): slot of this request's last input position
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        adapter_indices = self.adapter_meta.adapter_indices[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        """Merge several batches into one, keeping the given batch order.

        The result takes ``batch_id`` from ``batches[0]``. Per-batch slot
        indices and start slots are offset by the number of slots of the
        preceding batches. Prefill metadata is set to None.
        The input batches' ``requests_idx_mapping`` dicts are not modified.
        """
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        num_blocks = 0

        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )
        total_indices_size = sum(
            b.adapter_meta.adapter_indices.shape[0] for b in batches
        )
        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
            total_indices_size
        )
        adapter_set = set()
        adapter_segment_builder = SegmentConcatBuilder()

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0
        cumulative_adapter_indices_size = 0

        for batch in batches:
            requests.extend(batch.requests)

            # Offset each batch's mapping by the cumulative batch size (0 for
            # the first batch). Writing into a fresh dict avoids mutating the
            # first input batch's mapping.
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            # Copy over adapter indices
            adapter_start_index = cumulative_adapter_indices_size
            adapter_end_index = (
                cumulative_adapter_indices_size
                + batch.adapter_meta.adapter_indices.shape[0]
            )
            adapter_indices[adapter_start_index:adapter_end_index] = (
                batch.adapter_meta.adapter_indices
            )
            cumulative_adapter_indices_size = adapter_end_index
            adapter_set.update(batch.adapter_meta.adapter_set)
            adapter_segment_builder.concat(
                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
            )

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        # NOTE(review): assumes all batches agree on whether speculative_ids is
        # set; torch.cat would fail if only some of them carry it.
        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
            adapter_meta=AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            ),
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)


812
813
814
815
816
817
818
819
820
821
822
823
# Layer names listed for adapters, in this fixed order.
# NOTE(review): the consumer is outside this view — presumably the module names
# that LoRA adapter weights may target; confirm.
ADAPTER_LAYERS = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
# Layer names treated as row-parallel (NOTE(review): consumer not visible here).
ROW_PARALLEL = {"lm_head", "down_proj", "o_proj"}


824
825
826
class FlashCausalLM(Model):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = None,
        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
        config_class: PreTrainedTokenizerBase = AutoConfig,
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
        num_kv_heads: Optional[int] = None,
        # Deepseek V2 uses different QK and V dims.
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        """Load tokenizer, config and sharded weights, then build the model.

        Args:
            model_id: Model id or path passed to the tokenizer, config and
                weight loaders.
            model_class: Callable invoked as `model_class(prefix, config, weights)`
                with an empty prefix.
            revision: Revision forwarded to every `from_pretrained` / weight lookup.
            quantize: Quantization method; stored on the config and used to pick
                the weights loader.
            speculator: Stored on the config as `config.speculator`.
            dtype: Weight dtype. If None, `default_dtype` is used on CUDA/XPU and
                bfloat16 on IPEX CPU.
            trust_remote_code: Forwarded to the HF `from_pretrained` calls.
            lora_adapter_ids: Not referenced by this constructor.
            tokenizer_class: Class whose `from_pretrained` loads the tokenizer
                (left padding and left truncation).
            config_class: Class whose `from_pretrained` loads the model config.
            default_dtype: Fallback dtype on CUDA / XPU when `dtype` is None.
            aliases: Forwarded to `Weights`.
            num_kv_heads: Overrides the KV head count otherwise read from the
                config (`num_key_value_heads`, then `n_head`).
            head_size: Overrides the head size otherwise read from the config
                (`head_dim`, else `hidden_size // num_attention_heads`).
            skip_special_tokens: Not referenced by this constructor.

        Raises:
            NotImplementedError: if neither CUDA nor an IPEX backend is available.
            ValueError: if the number of KV heads cannot be determined, or if
                sharding leaves zero KV heads per rank.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError(f"{model_class} is only available on GPU")

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception as e:
            # Best effort: failing to load the generation config is not fatal,
            # but record why so it is not silently lost.
            logger.debug(f"Could not load generation config for {model_id}: {e}")

        config = config_class.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )

        prefix = ""
        model = model_class(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)

        # VLM models define the config we care about in their text_config
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None

        self.num_layers = config.num_hidden_layers
        # Validation is done in the model itself
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            # GPT-2 workaround
            if num_kv_heads is None:
                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")

        num_shards = self.process_group.size()
        # A single KV head is not split across shards; otherwise each rank
        # keeps its share of the heads.
        self.num_kv_heads = num_kv_heads // num_shards if num_kv_heads > 1 else num_kv_heads
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.num_kv_heads <= 0:
            raise ValueError(
                f"Invalid number of key/value heads per shard: {num_kv_heads} "
                f"key/value heads split across {num_shards} shards"
            )

        if head_size is None:
            # Some models use GQA and different sizes for o_proj
            # and q_proj, that allows for that.
            if hasattr(config, "head_dim"):
                self.head_size = config.head_dim
            else:
                self.head_size = config.hidden_size // config.num_attention_heads
        else:
            self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []

        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        """The batch class this model consumes and produces."""
        batch_cls = FlashCausalLMBatch
        return batch_cls

955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
    def max_past(self) -> Optional[int]:
        """Return the wrapped model's `max_past` attribute, or None if it has none.

        `forward` uses a non-None value to clamp `max_s` during decode.
        """
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """(Re)allocate the paged KV cache as one (key, value) pair per layer.

        Any existing cache is dropped first and `empty_cache()` is called
        before the new, uninitialized (`torch.empty`) tensors are created.

        The layout depends on the backend:
          * FLASH_DECODING: keys and values are
            (num_blocks, BLOCK_SIZE, num_heads, head_size).
          * IPEX on CPU: keys and values are
            (num_blocks, num_heads, BLOCK_SIZE, head_size).
          * otherwise: keys are (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
            and values are (num_blocks, num_heads, head_size, BLOCK_SIZE).

        `x` is 1 on IPEX XPU and `BLOCK_SIZE // element_size` elsewhere.
        """
        # Release references to the previous cache before freeing memory.
        self.kv_cache = []
        empty_cache()

        element_size = torch.tensor([], dtype=dtype).element_size()
        on_ipex_xpu = SYSTEM == "ipex" and device.type == "xpu"
        x = 1 if on_ipex_xpu else BLOCK_SIZE // element_size

        if FLASH_DECODING:
            key_shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
            value_shape = key_shape
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            key_shape = (num_blocks, num_heads, BLOCK_SIZE, head_size)
            value_shape = key_shape
        else:
            key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
            value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)

        def _alloc(shape):
            # A fresh tensor on every call: keys and values never alias.
            return torch.empty(shape, dtype=dtype, device=device)

        self.kv_cache = [
            (_alloc(key_shape), _alloc(value_shape)) for _ in range(num_layers)
        ]
1024

1025
1026
1027
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode-step CUDA graph for batch size `bs`.

        Static input buffers are allocated once and stored, together with the
        graph and its captured outputs, in `self.cuda_graphs[bs]`. The model is
        run once eagerly as a warmup, then captured into the graph using the
        shared `MEM_POOL`.
        """
        dev = self.device
        input_ids = torch.zeros(bs, dtype=torch.int64, device=dev)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=dev)
        slots = torch.arange(bs, dtype=torch.int64, device=dev)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=dev) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=dev)
            .repeat(bs)
            .reshape((bs, max_bt))
        )

        static_buffers = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        self.cuda_graphs[bs] = static_buffers
        eager_lengths = Seqlen(input_lengths=input_lengths)

        graph = torch.cuda.CUDAGraph()
        static_buffers["graph"] = graph

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=eager_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            # Built inside the capture so any work Seqlen does on the static
            # lengths buffer is part of the recorded graph.
            captured_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=captured_lengths,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            static_buffers["logits"] = logits
            static_buffers["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

1082
    def warmup(self, batch: FlashCausalLMBatch):
        """Size and allocate the KV cache, then warm up kernels.

        The warmup batch is the biggest batch we could ever receive. A cache of
        exactly `batch.num_blocks` blocks is allocated and one `generate_token`
        step is run on it. The remaining free device memory then determines how
        many cache blocks are allocated for serving. Afterwards, on ROCm,
        TunableOp may be enabled and tuned. CUDA graphs are captured for the
        configured batch sizes when `CUDA_GRAPHS` is set.

        Returns:
            The number of token slots in the final KV cache
            (`num_blocks * BLOCK_SIZE`).

        Raises:
            RuntimeError: if the warmup batch runs out of device memory.
        """
        empty_cache()

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                # Keep TunableOp from tuning during the warmup generation step.
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Keys and values (x2) for every layer.
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
        # generate_token may hand back no batch (e.g. when every request stopped).
        batch_num_blocks = batch.num_blocks if batch is not None else 0

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch_num_blocks
        )

        # Drop the warmup batch before re-allocating the (larger) cache.
        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )

        if SYSTEM == "rocm":
            self._setup_rocm_tunableop()

        if CUDA_GRAPHS:
            try:
                log_master(
                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception("Decode cuda graph warmup failed")
        else:
            log_master(
                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
            )

        return int(num_blocks * BLOCK_SIZE)

    def _setup_rocm_tunableop(self):
        """Enable PyTorch TunableOp on ROCm unless disabled, and tune target lengths.

        TunableOp is enabled when `PYTORCH_TUNABLEOP_ENABLED` is unset or "1".
        Tuned GEMMs are persisted to (and reused from) a per-model, per-rank CSV
        file under `HUGGINGFACE_HUB_CACHE`.
        """
        tunableop_env = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
        if tunableop_env is not None and tunableop_env != "1":
            log_master(
                logger.info,
                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
            )
            return

        torch.cuda.tunable.enable()

        if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
            torch.cuda.tunable.tuning_enable(True)

        if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
            tuning_sequences = [
                int(val)
                for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
            ]
        elif CUDA_GRAPHS is not None:
            tuning_sequences = CUDA_GRAPHS
        else:
            # For seqlen = 1, we dispatch to LLMM1 kernel.
            tuning_sequences = [2, 3, 4, 5, 6, 7]

        tunableop_filepath = os.path.join(
            HUGGINGFACE_HUB_CACHE,
            f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
        )

        log_master(
            logger.info,
            f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
        )

        if os.path.isfile(tunableop_filepath):
            log_master(
                logger.info,
                f"The file {tunableop_filepath} already exists and will be reused.",
            )
            torch.cuda.tunable.read_file(tunableop_filepath)

        os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

        for seqlen in tuning_sequences:
            log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
            self.tunableop_warmup(seqlen)
            # Persist after each length so progress survives an interruption.
            torch.cuda.tunable.write_file(tunableop_filepath)
        torch.cuda.tunable.tuning_enable(False)
1204

fxmarty's avatar
fxmarty committed
1205
1206
1207
1208
1209
    def tunableop_warmup(self, seqlen: int):
        """Run one prefill-style forward pass of `seqlen` dummy tokens.

        Used to exercise (and let TunableOp tune) the GEMMs for that length.
        The outputs are discarded.
        """
        dev = self.device
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=dev)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=dev)
        slots = torch.arange(seqlen, dtype=torch.int64, device=dev)

        # Dummy value, some models (starcoder2) don't accept `None`.
        dummy_lengths = Seqlen(
            input_lengths=torch.ones(seqlen, dtype=torch.int32, device=dev)
        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        single_sequence_bounds = torch.tensor(
            [0, seqlen], device=dev, dtype=torch.int32
        )
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=single_sequence_bounds,
            kv_cache=self.kv_cache,
            block_tables=None,
            input_lengths=dummy_lengths,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )

1230
    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on `batch` and return `(logits, speculative_logits)`.

        With speculative decoding, every request is expanded into
        `speculative_length + 1` rows: its current token followed by its
        speculative ids. Positions, slots, input lengths and block tables are
        expanded to match, so all candidates are scored in one pass.

        Prefill batches, and decode batches with no captured CUDA graph large
        enough, run eagerly. Other decode batches are copied into the smallest
        captured graph's static (padded) buffers, and the graph is replayed.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        speculative_ids = batch.speculative_ids
        if speculative_ids is not None:
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Smallest captured graph batch size that can hold this batch, if any.
        padded_bs = min((k for k in self.cuda_graphs if k >= bs), default=None)
        cuda_graph = self.cuda_graphs[padded_bs] if padded_bs is not None else None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            input_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                adapter_data=adapter_data,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # NOTE(review): `adapter_data` is not used on the graph-replay path
        # (graphs are captured without it); confirm adapter batches never
        # reach this branch.

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Padding rows get slot -1 and input length 0.
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()
        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
1339
1340
1341
1342

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
1343
1344
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
1345
        prefill = batch.cu_seqlen_prefill is not None
1346
        prefill_logprobs = batch.prefill_next_token_indices is not None
1347

drbh's avatar
drbh committed
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)
1376

1377
1378
        if prefill:
            next_token_logits = (
1379
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
1380
            )
Nicolas Patry's avatar
Nicolas Patry committed
1381
1382
            if speculative_logits is not None:
                speculative_logits = (
OlivierDehaene's avatar
OlivierDehaene committed
1383
1384
1385
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
Nicolas Patry's avatar
Nicolas Patry committed
1386
                )
drbh's avatar
drbh committed
1387
1388
1389
1390
            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
                len(batch)
            )

1391
1392
        else:
            next_token_logits = out
drbh's avatar
drbh committed
1393
            next_adapter_indices = batch.adapter_meta.adapter_indices
1394

Nicolas Patry's avatar
Nicolas Patry committed
1395
        speculate = get_speculate()
OlivierDehaene's avatar
OlivierDehaene committed
1396
1397
1398
1399
1400
1401
1402
1403
1404
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
Nicolas Patry's avatar
Nicolas Patry committed
1405
            speculate,
OlivierDehaene's avatar
OlivierDehaene committed
1406
1407
            batch.speculative_ids,
            speculative_logits,
1408
1409
        )

Nicolas Patry's avatar
Nicolas Patry committed
1410
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
Nicolas Patry's avatar
Nicolas Patry committed
1411
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
Nicolas Patry's avatar
Nicolas Patry committed
1412
1413
        )

1414
        if prefill:
1415
            if len(batch) > 1 and prefill_logprobs:
1416
1417
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
1418
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
1419
1420

            next_position_ids = batch.position_ids.new_empty(len(batch))
1421
1422
1423
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
1424
1425
1426
1427
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

1428
1429
1430
1431
1432
        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
1433
        stopped = True
1434
1435

        # Zipped iterator
OlivierDehaene's avatar
OlivierDehaene committed
1436
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
1437

1438
1439
1440
1441
        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

1442
        # For each member of the batch
Nicolas Patry's avatar
Nicolas Patry committed
1443
        index = 0
OlivierDehaene's avatar
OlivierDehaene committed
1444
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
1445
            # Indexing metadata
1446
1447
1448
            start_index = cumulative_length
            end_index = cumulative_length + input_length

1449
            if prefill:
1450
1451
1452
1453
1454
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

1455
1456
1457
1458
                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

drbh's avatar
drbh committed
1459
1460
1461
1462
1463
1464
                # Initialize adapter indices
                # In decode, we only have one token per row in the batch, so grab last index
                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
                    end_index - 1
                ]

1465
1466
                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
1467
1468
                if prefill_logprobs:
                    if len(batch) > 1:
drbh's avatar
drbh committed
1469
1470
1471
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
1472
1473
1474
1475
1476
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]
1477

Nicolas Patry's avatar
Nicolas Patry committed
1478
1479
1480
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1
1481
1482
1483

            cumulative_length += input_length

drbh's avatar
drbh committed
1484
        # Update values
Nicolas Patry's avatar
Nicolas Patry committed
1485
1486
1487
1488
1489
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids
drbh's avatar
drbh committed
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
        batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )
1500

1501
        if prefill and prefill_logprobs:
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
Nicolas Patry's avatar
Nicolas Patry committed
1512
        next_token_ids = next_input_ids.tolist()
1513
1514
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()
1515
1516
1517
1518
1519

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
1520
1521
            batch.prefix_offsets,
            batch.read_offsets,
1522
1523
            batch.stopping_criterias,
            batch.all_input_ids,
1524
1525
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
Nicolas Patry's avatar
Nicolas Patry committed
1526
            batch.top_n_tokens,
Nicolas Patry's avatar
Nicolas Patry committed
1527
            accepted_ids,
Nicolas Patry's avatar
Nicolas Patry committed
1528
1529
            batch_top_token_ids,
            batch_top_token_logprobs,
1530
1531
1532
        )

        # For each member of the batch
Nicolas Patry's avatar
Nicolas Patry committed
1533
        index = 0
1534
1535
1536
        for i, (
            request,
            input_length,
1537
1538
            prefix_offset,
            read_offset,
1539
1540
            stopping_criteria,
            all_input_ids,
1541
1542
            do_sample,
            seed,
Nicolas Patry's avatar
Nicolas Patry committed
1543
            top_n_tokens,
Nicolas Patry's avatar
Nicolas Patry committed
1544
            n_accepted_ids,
Nicolas Patry's avatar
Nicolas Patry committed
1545
1546
            top_token_ids,
            top_token_logprobs,
1547
        ) in enumerate(iterator):
1548
            # Append next token to all tokens
Nicolas Patry's avatar
Nicolas Patry committed
1549
1550
1551
            next_token_texts = []
            left = 0

1552
            if n_accepted_ids > 1:
1553
                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
1554

Nicolas Patry's avatar
Nicolas Patry committed
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)
1566

Nicolas Patry's avatar
Nicolas Patry committed
1567
1568
1569
1570
                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )
1571

Nicolas Patry's avatar
Nicolas Patry committed
1572
1573
1574
1575
1576
1577
1578
                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped
1579

OlivierDehaene's avatar
OlivierDehaene committed
1580
1581
1582
1583
            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
Nicolas Patry's avatar
Nicolas Patry committed
1584
            index += n_accepted_ids
1585

1586
1587
1588
1589
1590
            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
1591
1592
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
OlivierDehaene's avatar
OlivierDehaene committed
1593
1594
1595
1596
1597
1598
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
1599
1600
                    )
                    generated_text = GeneratedText(
1601
1602
1603
1604
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
1605
1606
1607
1608
1609
                    )
                else:
                    generated_text = None

                # Prefill
1610
1611
1612
1613
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

1614
1615
                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
1616
                        out_start_index : out_end_index - 1
1617
1618
1619
1620
1621
1622
1623
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
Nicolas Patry's avatar
Nicolas Patry committed
1624
1625

                    prefill_tokens = Tokens(
OlivierDehaene's avatar
OlivierDehaene committed
1626
1627
1628
1629
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
1630
1631
1632
1633
                    )
                else:
                    prefill_tokens = None

Nicolas Patry's avatar
Nicolas Patry committed
1634
                if top_n_tokens > 0:
Nicolas Patry's avatar
Nicolas Patry committed
1635
                    all_top_tokens = []
drbh's avatar
drbh committed
1636
                    for top_token_ids, top_token_logprobs in zip(
1637
1638
                        top_token_ids, top_token_logprobs
                    ):
Nicolas Patry's avatar
Nicolas Patry committed
1639
1640
1641
1642
1643
1644
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
1645
1646
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
Nicolas Patry's avatar
Nicolas Patry committed
1647
1648
1649
1650
1651
1652
1653
1654
1655
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
Nicolas Patry's avatar
Nicolas Patry committed
1656
1657
1658
                else:
                    top_tokens = None

1659
1660
1661
                generation = Generation(
                    request.id,
                    prefill_tokens,
Nicolas Patry's avatar
Nicolas Patry committed
1662
1663
1664
1665
1666
1667
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
1668
                    generated_text,
Nicolas Patry's avatar
Nicolas Patry committed
1669
                    top_tokens,
1670
1671
                )

1672
                generations.append(generation)
1673

drbh's avatar
drbh committed
1674
1675
1676
            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
OlivierDehaene's avatar
OlivierDehaene committed
1677
1678
1679
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )
drbh's avatar
drbh committed
1680

1681
            # Update values
1682
            batch.input_lengths[i] = input_length + n_accepted_ids
Nicolas Patry's avatar
Nicolas Patry committed
1683
1684
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
1685
1686
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
1687
1688
            batch.all_input_ids[i] = all_input_ids

1689
1690
        if stopped:
            # No need to return a batch if we know that all requests stopped
1691
1692
1693
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)
1694

1695
1696
1697
        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
1698

1699
1700
1701
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)