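"""flash_causal_lm.py

Batching and generation logic for causal language models served with
flash-attention kernels. Requests in a batch are packed into flat tensors
without padding and indexed through cumulative sequence lengths
(`cu_seqlens`); see `FlashCausalLMBatch` and `FlashCausalLM` below.
"""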
import torch
import torch.distributed

import numpy as np

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
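    """State for a batch of requests, stored without padding.

    Token ids and positions of all requests are concatenated into flat
    tensors; `cu_seqlens` records the cumulative sequence lengths used to
    recover each request's slice.
    """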
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(0)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(np.arange(0, input_length))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens

        # Create tensors on device
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=[],
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
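        """Keep only the given requests, slicing the packed tensors and the
        past key/value cache accordingly. When a single request remains, its
        past is pre-padded with room for the remaining `max_new_tokens`."""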
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = self.input_ids.new_empty(len(requests))
        position_ids = self.position_ids.new_empty(len(requests))
        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
        cu_seqlens_q = torch.arange(
            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy tensors (GPU)
            input_ids[i] = self.input_ids[idx]
            position_ids[i] = self.position_ids[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
            )

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
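        """Merge several batches into one: concatenate the packed tensors,
        re-offset `requests_idx_mapping` and `cu_seqlens`, and strip the
        pre-allocated padding from single-request pasts before concatenating
        the key/value caches."""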
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        cu_seqlens = [0]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
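    """Base model class for causal LMs served with flash-attention packed
    batches. Requires a CUDA device; weights are loaded in float16."""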
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        ).to(device)

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
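        """Run a single forward pass over a packed batch and return the
        logits together with the present key/value cache."""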
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
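        """Run one generation step for the whole batch.

        The first call (prefill, i.e. `past_key_values is None`) processes the
        full prompts and builds the key/value cache; later calls decode one
        token per request. Returns the generations produced at this step and
        the updated batch, or None if every request has stopped.
        """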
        prefill = batch.past_key_values is None

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
        )

        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When len(batch) == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
            next_input_ids = batch.input_ids.new_empty(len(batch))
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_input_ids = batch.input_ids
            next_position_ids = batch.position_ids

        next_token_logprobs = out.new_empty(len(batch))

        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
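            # Allocate the next past with one extra slot per request so the
            # token generated at this step has room in the key/value cache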
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to run a separate for loop here, as the copy from present to past is fairly slow
            # and will run asynchronously while we execute the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # We do two for loops: the first one can run completely asynchronously from the GPU, while the second
        # one first needs a GPU <-> CPU sync
        # It is faster if we delay this sync for as long as possible

        # For each member of the batch
        for i, (
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                # only take last token logit
                logits = out[end_index - 1 : end_index]

                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
                all_input_ids_tensor = batch.input_ids.new_empty(
                    input_length + stopping_criteria.max_new_tokens
                )
                # Copy from batch.input_ids to all_input_ids_tensor
                all_input_ids_tensor[:input_length] = batch.input_ids[
                    start_index:end_index
                ]
                batch.all_input_ids_tensor.append(all_input_ids_tensor)

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].view(1, -1)

            all_input_ids_tensor = batch.all_input_ids_tensor[i]

            # Select next token
            next_token_id, logprob = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )

            # Add to all_input_ids_tensor
            next_token_id_squeezed = next_token_id.view(1)
            all_input_ids_tensor[input_length] = next_token_id_squeezed

            # Set values
            next_input_ids[i] = next_token_id_squeezed
            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
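        # cu_seqlens_q is arange(0, len(batch) + 1), so adding it shifts each
        # cumulative boundary by its index, i.e. one new token per sequence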
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

        if prefill:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill:
                    # Remove the generated token to keep only prefill tokens, and add nan for the first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        start_index : end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids
            batch.max_seqlen = batch.max_seqlen + 1
            cumulative_length += input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None