import torch
import torch.distributed

import numpy as np

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
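    # Tokens of all requests are flattened along a single dimension (the
    # "varlen" layout used by flash attention kernels); cu_seqlens stores the
    # cumulative sequence lengths that mark the request boundaries.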
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Current lengths of all sequences present in the batch
    input_lengths: List[int]
    offsets: List[Optional[int]]
    token_offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        offsets = []
        token_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            offsets.append(None)
            token_offsets.append(None)

            all_input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(np.arange(0, input_length))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens

        # Create tensors on device
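        # All prompts are concatenated into a single 1D tensor; individual
        # requests are recovered through cu_seqlens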
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=[],
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
        if len(requests) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(requests) == len(self):
            return self

        single_request = len(requests) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        input_ids = self.input_ids.new_empty(len(requests))
        position_ids = self.position_ids.new_empty(len(requests))
        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
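        # In decode mode every request contributes exactly one query token, so
        # the cumulative query lengths are simply 0, 1, ..., len(requests)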
        cu_seqlens_q = torch.arange(
            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        max_tokens = 0

        for i, r in enumerate(requests):
            idx = self.requests_idx_mapping[r.id]
            requests_idx_mapping[r.id] = i

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy tensors (GPU)
            input_ids[i] = self.input_ids[idx]
            position_ids[i] = self.position_ids[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
            )

            all_input_ids.append(self.all_input_ids[idx])
            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

            input_lengths.append(request_input_length)
            offsets.append(self.offsets[idx])
            token_offsets.append(self.token_offsets[idx])

            next_token_choosers.append(self.next_token_choosers[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
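            # Pad the token dimension with the number of tokens this request can
            # still generate, so later decode steps have room to append their
            # key/value entries without re-allocating the cache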
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        cu_seqlens = [0]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []
        all_input_ids_tensor = []

        input_lengths = []
        offsets = []
        token_offsets = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)

            input_lengths.extend(batch.input_lengths)
            offsets.extend(batch.offsets)
            token_offsets.extend(batch.token_offsets)

            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
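            # cu_seqlens[-1] is the total number of tokens in this batch and
            # becomes the offset for the next batch's cumulative lengths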
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            offsets=offsets,
            token_offsets=token_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: bool = False,
        decode_buffer: int = 3,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
            )
            .eval()
            .to(device)
        )

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            decode_buffer=decode_buffer,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
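        # No past key values means this is the prefill step: the forward pass
        # runs over the full prompts. Afterwards the batch only carries the
        # last token of each request and we are in decode mode.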
        prefill = batch.past_key_values is None

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
        )

        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
            next_input_ids = batch.input_ids.new_empty(len(batch))
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_input_ids = batch.input_ids
            next_position_ids = batch.position_ids

        next_token_logprobs = out.new_empty(len(batch))

        # Prepare past for next decode
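        # For batch sizes > 1, the past is re-allocated with one extra token
        # slot per request so the next forward pass has room to append the
        # key/value entries of the token generated at this step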
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to do a separate for loop here, as the copy from present to past is fairly slow
            # and will run asynchronously while we do the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
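                # Request i's tokens are shifted right by i positions, leaving
                # one free slot after each request's block for its next token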
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # We do two for loops: the first one can run completely asynchronously from the GPU, while for the second
        # one, we first need a GPU <-> CPU sync
        # It is faster if we delay this sync for as long as possible

        # For each member of the batch
        for i, (
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                # only take last token logit
                logits = out[end_index - 1 : end_index]

                # Create all_input_ids_tensor that will be used by the logits processors and warpers (for example, RepetitionPenalty)
                all_input_ids_tensor = batch.input_ids.new_empty(
                    input_length + stopping_criteria.max_new_tokens
                )
                # Copy from batch.input_ids to all_input_ids_tensor
                all_input_ids_tensor[:input_length] = batch.input_ids[
                    start_index:end_index
                ]
                batch.all_input_ids_tensor.append(all_input_ids_tensor)

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_tokens_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].view(1, -1)

            all_input_ids_tensor = batch.all_input_ids_tensor[i]

            # Select next token
            next_token_id, logprob = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )

            # Add to all_input_ids_tensor
            next_token_id_squeezed = next_token_id.view(1)
            all_input_ids_tensor[input_length] = next_token_id_squeezed

            # Set values
            next_input_ids[i] = next_token_id_squeezed
            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
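        # Every sequence grew by one token: cu_seqlens_q is [0, 1, ..., batch_size],
        # so adding it shifts each cumulative boundary by the number of
        # sequences that precede it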
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

        if prefill:
            # Get prefill logprobs
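            # prefill_tokens_indices holds, for each prompt position, the id of
            # the token that follows it, so the gather below recovers the
            # logprob the model assigned to every prompt token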
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.offsets,
            batch.token_offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            offset,
            token_offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, offset, token_offset = self.decode_token(
                all_input_ids,
                offset,
                token_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill:
                    # Remove the generated token to keep only prefill tokens, and use NaN for the first prompt token (it has no logprob)
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        start_index : end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.offsets[i] = offset
            batch.token_offsets[i] = token_offset
            batch.all_input_ids[i] = all_input_ids
            batch.max_seqlen = batch.max_seqlen + 1
            cumulative_length += input_length

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None