import torch
import inspect

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
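        # Illustrative sketch (hypothetical numbers, not from the original source): with
        # max_input_length=4 and padding_right_offset=3, each row has 7 slots; columns 0-3
        # hold the tokenizer mask for the left-padded prompt, and columns 4-6 stay 0 until
        # generate_token flips them on, one per decode step.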

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
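        # Worked example (hypothetical values): a left-padded row with
        # attention_mask = [0, 0, 1, 1, 1] gives cumsum(-1) - 1 = [-1, -1, 0, 1, 2], and the
        # masked_fill_ turns it into [1, 1, 0, 1, 2]: real tokens get positions 0, 1, 2
        # while the padding positions hold a harmless placeholder.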
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens
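        # Budget sketch (hypothetical numbers): two requests padded to 4 input tokens with
        # max_new_tokens of 3 and 5 give max_tokens = 2 * 4 + (3 + 5) = 16, an upper bound
        # on the number of token slots this batch can ever occupy.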

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]
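        # Slice example (hypothetical sizes, not from the original source): if the old mask
        # has 10 columns with padding_right_offset=3, and the kept requests have
        # max_input_length=5 and new_padding_right_offset=2, the slice keeps columns
        # -(3 + 5) = 2 up to (10 - 3) + 2 = 9, i.e. 5 + 2 = 7 columns: the live tokens plus
        # the decode slots the remaining requests can still use.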

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]
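            # Offset sketch (hypothetical sizes): if the merged mask is 6 + 4 = 10 columns
            # wide and this batch has max_input_length=4 and padding_right_offset=2, then
            # left_offset = 6 - 4 = 2 and its 4 live columns are copied into columns 2-5 of
            # the merged mask, i.e. right-aligned within the merged prompt region.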

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])
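            # For reference (shapes inferred from the comments above): a standard past key
            # is [batch_size, num_heads, seq_length, head_dim], while the BLOOM-style 3D
            # keys [batch_size * num_heads, head_dim, seq_length] are viewed back to
            # [batch_size, num_heads, head_dim, seq_length] so they can be indexed per request.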

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )

        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional[List[Tuple]] = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :, 0]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1
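        # Illustrative step (hypothetical sizes): with a 9-column mask and
        # padding_right_offset=3, column -3 (index 6) is flipped to 1 for the token that
        # was just generated, and the offset drops to 2 so the next step writes column 7.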

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch