import torch
import torch.distributed

import numpy as np

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
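    """Unpadded ("flattened") batch used by flash-attention based models.

    The tokens of every request are concatenated along a single dimension;
    `cu_seqlens` holds the cumulative sequence lengths used to locate each
    request inside the flat `input_ids` / `position_ids` / past tensors.
    """
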
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    # cumulative query sequence lengths, only used in decode
    cu_seqlens_q: Optional[torch.Tensor]
    # past key values, only used in decode
    past_key_values: Optional[torch.Tensor]
    max_seqlen: int

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
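        """Tokenize the requests of a protobuf `Batch` and build the flat
        input, position and cumulative-length tensors on `device`."""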
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        max_tokens = 0
        max_length = 0

        # Parse batch
        for i, r in enumerate(pb.requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenizer(
                r.inputs, truncation=True, max_length=r.truncate
            )["input_ids"]

            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)

            prefix_offsets.append(0)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(np.arange(0, input_length))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Update
            cumulative_length += input_length
            max_tokens += input_length + max_new_tokens
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        input_ids = torch.tensor(
            np.concatenate(all_input_ids), dtype=torch.int64, device=device
        )
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )
        position_ids = torch.tensor(
            np.concatenate(position_ids), dtype=torch.int32, device=device
        )
        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=None,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
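        """Keep only the given `request_ids`: re-slice every tensor
        (including the past key values) and recompute the cumulative
        sequence lengths for the remaining requests."""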
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        single_request = len(request_ids) == 1

        # Cumulative length
        cumulative_length = 0

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # Create on CPU to only move to GPU once instead of at every copy
        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
        cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
        max_seqlen = 0
        past_key_values = []

        requests = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        max_tokens = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]

            # Copy to tensor (CPU)
            cu_seqlens[i + 1] = cumulative_length + request_input_length
            max_seqlen = max(max_seqlen, request_input_length)

            # Slice from past
            past_key_values.append(
                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
            )

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            cumulative_length += request_input_length
            max_tokens += request_input_length + (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

        if single_request:
            # Preallocate tensor for bs = 1 case
            past_key_values = F.pad(
                past_key_values[0],
                (
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    stopping_criterias[0].max_new_tokens
                    - stopping_criterias[0].current_tokens,
                ),
            )
        else:
            # Cat all past
            past_key_values = torch.cat(past_key_values, dim=1)

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)

        # Move to GPU now that we have the whole tensor
        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
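        """Merge several batches into a single one by concatenating their
        flat tensors and offsetting each batch's cumulative lengths and
        request index mapping."""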
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        total_batch_size = sum([len(b) for b in batches])

        dtype = batches[0].past_key_values.dtype
        device = batches[0].input_ids.device

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        cu_seqlens = [0]
        cu_seqlens_q = torch.arange(
            0, total_batch_size + 1, device=device, dtype=torch.int32
        )
        max_seqlen = 0
        past_key_values = []

        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_length = 0
        max_tokens = 0
        max_length = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids

            # Add cumulative lengths of all previous inputs
            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
            max_seqlen = max(max_seqlen, batch.max_seqlen)

            if len(batch) != 1:
                past_key_values.append(batch.past_key_values)
            else:
                # past was pre-allocated for this batch
                # We need to slice to remove the padding
                past_key_values.append(
                    batch.past_key_values[:, : batch.input_lengths[0]]
                )

            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_length += batch.cu_seqlens[-1]
            cumulative_batch_size += len(batch)
            max_tokens += batch.max_tokens
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        batch.input_lengths, batch.stopping_criterias
                    )
                ),
            )

        all_input_ids_tensor = torch.zeros(
            (total_batch_size, max_length), dtype=torch.int64, device=device
        )

        cumulative_batch_size = 0
        for i, batch in enumerate(batches):
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            cumulative_batch_size += len(batch)

        # Cat past
        past_key_values = torch.cat(past_key_values, dim=1)
        # Create final tensor on GPU
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype=dtype, device=device
        )

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
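    """Causal language model using the unpadded `FlashCausalLMBatch` layout;
    only available on GPU (CUDA)."""
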
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        ).to(device)

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        max_s: int,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
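        """Run a single forward pass and return the output logits together
        with the present key/value cache."""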
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            max_s=max_s,
            past_key_values=past_key_values,
            pre_allocate_past_size=pre_allocate_past_size,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
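        """Run one forward pass over the whole batch, pick one new token per
        request and update the batch in place for the next step. Returns the
        generations and the updated batch, or None when every request has
        stopped."""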
        prefill = batch.past_key_values is None
        single_request = len(batch) == 1

        if prefill and len(batch) == 1:
            # Ask to pre-allocate kv to its max size
            # == number of tokens + max_new_tokens
            pre_allocate_past_size = (
                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
            )
        else:
            pre_allocate_past_size = None

        out, present = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.cu_seqlens,
            batch.cu_seqlens_q,
            batch.max_seqlen,
            batch.past_key_values,
            pre_allocate_past_size,
        )

        if prefill:
            next_token_logits = (
                out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        if prefill:
            if len(batch) > 1:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))

            # Create batch.cu_seqlens_q for decode
            batch.cu_seqlens_q = torch.arange(
                0, len(batch) + 1, device=self.device, dtype=torch.int32
            )
            next_position_ids = batch.position_ids.new_empty(len(batch))
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Prepare past for next decode
        if len(batch) > 1:
            # Used to slice next batch past
            past_indices = torch.empty(
                present.shape[1], dtype=torch.int64, device=self.device
            )
            batch.past_key_values = present.new_empty(
                (
                    present.shape[0],
                    present.shape[1] + len(batch.requests),
                    *present.shape[2:],
                )
            )

            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
            # and will run asynchronously while we do the next for loop
            cumulative_length = 0
            for i, input_length in enumerate(batch.input_lengths):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length

                # Indices to copy present at the correct place in past_key_values
                torch.arange(
                    start_index + i,
                    end_index + i,
                    dtype=torch.int64,
                    device=self.device,
                    out=past_indices[start_index:end_index],
                )
                cumulative_length += input_length

            # Copy from present to past_key_values
            batch.past_key_values[:, past_indices] = present

        # Initialize past_key_values in prefill for len(batch) == 1
        elif prefill:
            # present is already pre-padded
            batch.past_key_values = present

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if len(batch) > 1:
                    prefill_tokens_indices[
                        start_index : end_index - 1
                    ] = batch.input_ids[start_index + 1 : end_index]
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
                        start_index + 1 : end_index
                    ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
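        # cu_seqlens_q is arange(0, len(batch) + 1): every sequence grew by one
        # token, so each cumulative boundary shifts by its request index.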
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

        if prefill:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        cumulative_length = 0

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        start_index : end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            new_input_length = input_length + 1

            # Update values
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids
            cumulative_length += input_length

        batch.max_seqlen = batch.max_seqlen + 1

        # No need to return a batch if we know that all requests stopped
        return generations, batch if not stopped else None