import torch
import torch.distributed

import numpy as np

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
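    """Flattened batch of requests for flash-attention models.

    All sequences are concatenated along a single token dimension and
    delimited by cumulative sequence lengths (cu_seqlens) instead of padding.
    """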
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
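        """Convert this batch to its protobuf representation."""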
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
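        """Tokenize the requests of a protobuf batch and build the prefill tensors on `device`."""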
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(0)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(np.arange(0, input_length))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens

        # Create tensors on device
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=[],
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
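        """Keep only the given requests, slicing the batch tensors and past key/value cache accordingly."""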
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = self.input_ids.new_empty(len(requests))
        position_ids = self.position_ids.new_empty(len(requests))
        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
        cu_seqlens_q = torch.arange(
            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy tensors (GPU)
            input_ids[i] = self.input_ids[idx]
            position_ids[i] = self.position_ids[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
            )

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
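        """Merge several batches into one, offsetting cumulative lengths and request index mappings."""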
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        cu_seqlens = [0]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
        ).to(device)

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
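        """Decode generated token ids into text, skipping special tokens."""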
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
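        """Run the model forward pass; returns the logits and the present key/value cache."""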
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
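        """Run one prefill or decode step and generate a single new token for every request in the batch."""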
        prefill = batch.past_key_values is None

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
        )

        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
            next_input_ids = batch.input_ids.new_empty(len(batch))
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_input_ids = batch.input_ids
            next_position_ids = batch.position_ids

        next_token_logprobs = out.new_empty(len(batch))

        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
            # and will run asynchronously while we do the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                # only take last token logit
                logits = out[end_index - 1 : end_index]

                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
                all_input_ids_tensor = batch.input_ids.new_empty(
                    input_length + stopping_criteria.max_new_tokens
                )
                # Copy from batch.input_ids to all_input_ids_tensor
                all_input_ids_tensor[:input_length] = batch.input_ids[
                    start_index:end_index
                ]
                batch.all_input_ids_tensor.append(all_input_ids_tensor)

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_tokens_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].view(1, -1)

            all_input_ids_tensor = batch.all_input_ids_tensor[i]

            # Select next token
            next_token_id, logprob = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )

            # Add to all_input_ids_tensor
            next_token_id_squeezed = next_token_id.view(1)
            all_input_ids_tensor[input_length] = next_token_id_squeezed

            # Set values
            next_input_ids[i] = next_token_id_squeezed
            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

        if prefill:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill:
                    # Drop the generated token so only prefill tokens remain; the first prompt token gets a NaN logprob
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        start_index : end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids
            batch.max_seqlen = batch.max_seqlen + 1
            cumulative_length += input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None