import torch
import torch.distributed

import numpy as np

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int
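
    # Illustrative layout (not executed code): flash attention works on a single
    # flattened sequence, so for two prompts of lengths 3 and 5 this batch stores
    #   input_ids    -> 8 token ids concatenated
    #   position_ids -> [0, 1, 2, 0, 1, 2, 3, 4]
    #   cu_seqlens   -> [0, 3, 8]   (prefix sums of the per-request lengths)
    #   max_seqlen   -> 5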

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        prefill_out_cumulative_length = 0

        max_tokens = 0
        max_length = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids
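        # Illustrative example (not executed code): with all_input_ids
        # [[5, 6, 7], [8, 9]] and max_length 6, the padded tensor is
        #   [[5, 6, 7, 0, 0, 0],
        #    [8, 9, 0, 0, 0, 0]]
        # so generated tokens can later be written in place per row.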

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]

        # Create tensors on device
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlens[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlens[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
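
        # Illustrative example (not executed code): for two requests of lengths 3
        # and 2 where only the first asked for prefill logprobs, the loop above
        # produced
        #   prefill_head_indices       -> [0, 1, 2, 4]  (rows of the prompt kept in `out`)
        #   prefill_next_token_indices -> [2, 3]        (rows of `out` used to sample next tokens)
        #   prefill_cu_outlens         -> [0, 3, 4]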

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        single_request = len(request_ids) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
        cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
        max_seqlen = 0
        past_key_values = []

        requests = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        max_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
            )

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)
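
        # Note on the single-request branch above (illustrative): F.pad pads
        # trailing dimensions first, so the final (0, remaining_new_tokens) pair
        # grows the token dimension (dim 1) of the kept past, pre-allocating the
        # slots this request may still fill during decode.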

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
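
        # Illustrative example (not executed code): if the original batch held
        # requests_idx_mapping {11: 0, 12: 1, 13: 2} and request_ids == [13, 11],
        # then indices == [2, 0] and the new requests_idx_mapping is {13: 0, 11: 1};
        # the gathers above reorder the per-request tensors the same way.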

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        dtype = batches[0].past_key_values.dtype
        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        cu_seqlens = [0]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0
        max_length = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)
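
            # Illustrative example (not executed code): concatenating a batch with
            # cu_seqlens [0, 3, 8] and one with [0, 2, 6] yields
            # [0, 3, 8, 10, 14], since the second batch is shifted by the 8 tokens
            # already accumulated.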

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        batch.input_lengths, batch.stopping_criterias
                    )
                ),
            )

        all_input_ids_tensor = torch.zeros(
            (total_batch_size, max_length), dtype=torch.int64, device=device
        )

        cumulative_batch_size = 0
        for i, batch in enumerate(batches):
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            cumulative_batch_size += len(batch)

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype=dtype, device=device
        )

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        ).to(device)

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.past_key_values is None
        prefill_logprobs = batch.prefill_next_token_indices is not None
        single_request = len(batch) == 1

        if prefill and single_request:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
            batch.prefill_head_indices,
        )

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
            # and will run asynchronously while we do the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present
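
            # Illustrative example (not executed code): with input_lengths [2, 3],
            # present holds 5 tokens and the new past is allocated with 7 slots;
            # past_indices == [0, 1, 3, 4, 5], so slots 2 and 6 stay free for the
            # token each request generates on the next decode step.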

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
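        # Illustrative example (not executed code): after a prefill with
        # cu_seqlens [0, 3, 8] and cu_seqlens_q [0, 1, 2], the new cu_seqlens is
        # [0, 4, 10]: every sequence grew by exactly one token.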

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None