import torch
import torch.distributed

import numpy as np

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
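    # Illustrative example (added note, not in the original source): for three
    # requests with prompt lengths 2, 5 and 3, input_ids holds the 10 tokens
    # concatenated back to back and cu_seqlens is [0, 2, 7, 10], so request i
    # occupies input_ids[cu_seqlens[i] : cu_seqlens[i + 1]].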
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(0)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(np.arange(0, input_length))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens
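            # Example (illustrative, added note): a 10-token prompt with
            # max_new_tokens=20 contributes 30 to max_tokens, i.e. the worst-case
            # number of token slots this request may occupy.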

        # Create tensors on device
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=[],
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        single_request = len(request_ids) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = self.input_ids.new_empty(len(request_ids))
        position_ids = self.position_ids.new_empty(len(request_ids))
        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
        cu_seqlens_q = torch.arange(
            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        requests = []
        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy tensors (GPU)
            input_ids[i] = self.input_ids[idx]
            position_ids[i] = self.position_ids[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
            )
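            # Note (added commentary): dim 1 of past_key_values is assumed to be the
            # flattened token dimension, so this slice keeps exactly the cached
            # keys/values belonging to this request.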

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
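            # Note (added commentary): F.pad with an 8-value tuple pads the last four
            # dimensions; only the final (0, N) pair is non-zero which, assuming the
            # token dimension sits four axes from the end, grows the cache by the
            # number of tokens this request can still generate so that decode steps
            # can write into it in place.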
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        cu_seqlens = [0]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)
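            # Illustrative example (added note): merging batches whose cu_seqlens are
            # [0, 2, 7] and [0, 3] yields [0, 2, 7, 10] once the second batch is
            # offset by the 7 tokens already accumulated.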

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        ).to(device)

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.past_key_values is None

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
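            # Example (illustrative, added note): a 12-token prompt with
            # max_new_tokens=20 pre-allocates 32 key/value slots, so the cache never
            # has to grow while this request decodes.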
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
        )

        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
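            # Note (added commentary): during decode each request contributes exactly
            # one query token per step, so the cumulative query lengths are simply
            # 0, 1, ..., batch_size.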
            next_input_ids = batch.input_ids.new_empty(len(batch))
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_input_ids = batch.input_ids
            next_position_ids = batch.position_ids

        next_token_logprobs = out.new_empty(len(batch))

        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
            # and will run asynchronously while we do the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present
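            # Illustrative example (added note): in a prefill step with input lengths
            # [2, 3], present has 5 token slots, the new past has 7, and past_indices
            # is [0, 1, 3, 4, 5], leaving slots 2 and 6 free for each request's next
            # token.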

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                # only take last token logit
                logits = out[end_index - 1 : end_index]

                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
                all_input_ids_tensor = batch.input_ids.new_empty(
                    input_length + stopping_criteria.max_new_tokens
                )
                # Copy from batch.input_ids to all_input_ids_tensor
                all_input_ids_tensor[:input_length] = batch.input_ids[
                    start_index:end_index
                ]
                batch.all_input_ids_tensor.append(all_input_ids_tensor)

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].view(1, -1)

            all_input_ids_tensor = batch.all_input_ids_tensor[i]

            # Select next token
            next_token_id, logprob = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )

            # Add to all_input_ids_tensor
            next_token_id_squeezed = next_token_id.view(1)
            all_input_ids_tensor[input_length] = next_token_id_squeezed

            # Set values
            next_input_ids[i] = next_token_id_squeezed
            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
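        # Illustrative example (added note): with cu_seqlens [0, 2, 7, 10] and
        # cu_seqlens_q [0, 1, 2, 3], the new boundaries are [0, 3, 9, 13], i.e. every
        # sequence grew by exactly one token.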

        if prefill:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()
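            # Note (added commentary): prefill_tokens_indices holds each prompt token
            # shifted left by one, so the logprob gathered at output position t is
            # the probability the model assigned to token t+1 given the prefix up to
            # and including token t.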

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        start_index : end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids
            batch.max_seqlen = batch.max_seqlen + 1
            cumulative_length += input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None