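"""Generic causal language model serving code.

`CausalLMBatch` holds the padded tensors and per-request state for a batch of
generation requests; `CausalLM` wraps a `transformers` causal LM to run prefill
and incremental decoding over such batches.
"""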
import torch
import inspect

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)


@dataclass
class CausalLMBatch(Batch):
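    """Per-batch state kept between decoding steps: padded input tensors, the
    past key/value cache, and per-request generation bookkeeping."""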
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    max_input_length: int
    padding_right_offset: int

    # Maximum number of tokens this batch will grow to
    max_tokens: int

    # Past metadata
    keys_head_dim_last: bool = True

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        prefix_offsets = []
        read_offsets = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            inputs.append(r.inputs)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(0)
            read_offsets.append(input_len)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length.item(),
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
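        """Keep only the requests in `request_ids`, slicing tensors and the past
        key/value cache in place so decoding can continue on the smaller batch."""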
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self

        keep_indices = []

        # New values after filtering
        requests_idx_mapping = {}
        requests = []
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        max_input_length = 0

        next_token_choosers = []
        stopping_criterias = []

        total_remaining_decode_tokens = 0
        new_padding_right_offset = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            keep_indices.append(idx)

            requests.append(self.requests[idx])
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
            all_input_ids.append(self.all_input_ids[idx])

            request_input_length = self.input_lengths[idx]
            input_lengths.append(request_input_length)
            max_input_length = max(max_input_length, request_input_length)

            next_token_choosers.append(self.next_token_choosers[idx])
            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)
            remaining_decode_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            total_remaining_decode_tokens += remaining_decode_tokens
            new_padding_right_offset = max(
                new_padding_right_offset, remaining_decode_tokens
            )

        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
        input_ids = self.input_ids[keep_indices]
        position_ids = self.position_ids[keep_indices]
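        # Slice the attention mask: drop surplus left padding now that
        # max_input_length may have shrunk, and keep only new_padding_right_offset
        # columns of the reserved right padding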
        self.attention_mask = self.attention_mask[
            keep_indices,
            -(self.padding_right_offset + max_input_length) : (
                self.attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
        ]

        # Ensure that past_key_values tensors can be updated in-place
        if type(self.past_key_values[0]) == tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
        past_kv_length = max_input_length - 1
        for layer in self.past_key_values:
            past_keys, past_values = layer
            if len(past_keys.shape) == 3:
                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
            if self.keys_head_dim_last:
                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
            else:
                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
            del past_keys
            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
            del past_values

        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.all_input_ids = all_input_ids
        self.input_lengths = input_lengths
        self.prefix_offsets = prefix_offsets
        self.read_offsets = read_offsets
        self.next_token_choosers = next_token_choosers
        self.stopping_criterias = stopping_criterias
        self.max_input_length = max_input_length
        self.padding_right_offset = new_padding_right_offset
        self.max_tokens = max_tokens

        return self

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
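        """Merge several prefilled batches into one, re-padding the attention masks
        and past key/value tensors to the largest input length in the group."""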
        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += len(batch)
            max_input_length = max(max_input_length, batch.max_input_length)
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        requests_idx_mapping = {}
        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
        max_tokens = 0

        # Batch tensors
        input_ids = None
        attention_mask = None
        position_ids = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + start_index

            # Slicing end index for this batch
            end_index = start_index + len(batch)

            # We only concatenate batches that did at least one step
            if batch.past_key_values is None:
                raise ValueError("only concatenate prefilled batches")

            # Create empty tensor
            # input_ids is always of shape [batch_size, 1]
            # We do not need to pad it
            if input_ids is None:
                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
            # Copy to correct indices
            input_ids[start_index:end_index] = batch.input_ids

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length + padding_right_offset),
                )

            # We need to slice the attention mask to remove padding from previous steps
            # and to remove unused allocated space
            left_offset = max_input_length - batch.max_input_length
            batch_left_offset = (
                batch.attention_mask.shape[1]
                - batch.max_input_length
                - batch.padding_right_offset
            )
            attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
            ] = batch.attention_mask[
                :,
                batch_left_offset : -batch.padding_right_offset,
            ]

            # Create empty tensor
            # position_ids is always of shape [batch_size, 1]
            if position_ids is None:
                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
            position_ids[start_index:end_index] = batch.position_ids

            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
            # And ensure that we can update tensors in-place
            if type(batch.past_key_values[0]) == tuple:
                batch.past_key_values = [
                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                    for layer in batch.past_key_values
                ]
            elif len(batch.past_key_values[0][0].shape) == 3:
                for layer in batch.past_key_values:
                    for k, t in enumerate(layer):
                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])

            # Add eventual padding tokens that were added while concatenating
            max_tokens += batch.max_tokens + (
                max_input_length - batch.max_input_length
            ) * len(batch)

            start_index = end_index

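        # After the reshapes above, values are [batch_size, num_heads, seq_length, head_dim]
        # (keys may have head_dim and seq_length swapped when keys_head_dim_last is False),
        # so the padded target shapes can be read off the first batch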
        first_past_kvs = batches[0].past_key_values
        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape

        padded_past_values_shape = (
            total_batch_size,
            num_heads,
            max_input_length - 1,
            head_dim,
        )

        if batches[0].keys_head_dim_last:
            padded_past_keys_shape = padded_past_values_shape
        else:
            # seq_length is last for BLOOM
            padded_past_keys_shape = (
                total_batch_size,
                num_heads,
                head_dim,
                max_input_length - 1,
            )

        # Iterate over attention layers
        # Concatenate past key values layer by layer to allow incremental garbage collection
        for j in range(len(first_past_kvs)):
            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
            start_index = 0
            for batch in batches:
                past_keys = batch.past_key_values[j][0]
                # Clear reference to the original tensor
                batch.past_key_values[j][0] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the keys to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                if batch.keys_head_dim_last:
                    padded_past_keys[
                        start_index:end_index, :, -past_seq_len:, :
                    ] = past_keys[:, :, -past_seq_len:, :]
                else:
                    # BLOOM case
                    padded_past_keys[
                        start_index:end_index, :, :, -past_seq_len:
                    ] = past_keys[:, :, :, -past_seq_len:]
                del past_keys

                start_index = end_index

            padded_past_values = first_past_kvs[j][1].new_zeros(
                padded_past_values_shape
            )
            start_index = 0
            for batch in batches:
                past_values = batch.past_key_values[j][1]
                # Clear reference to the original tensor
                batch.past_key_values[j][1] = None

                # Slicing end index for this batch
                end_index = start_index + len(batch)
                # We slice the past values to remove the padding from previous batches
                past_seq_len = batch.max_input_length - 1
                padded_past_values[
                    start_index:end_index, :, -past_seq_len:, :
                ] = past_values[:, :, -past_seq_len:, :]
                del past_values

                # Update values
                start_index = end_index

            past_key_values.append([padded_past_keys, padded_past_values])

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
            max_tokens=max_tokens,
        )

    def __len__(self):
        return len(self.requests)


class CausalLM(Model):
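    """Serving wrapper around a `transformers` `AutoModelForCausalLM` checkpoint."""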
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto"
            if torch.cuda.is_available() and torch.cuda.device_count() > 1
            else None,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )

        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
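        """Single forward pass; position_ids are only passed to models whose
        forward signature accepts them."""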
        # Model Forward
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
            "return_dict": True,
        }
        if self.has_position_ids:
            kwargs["position_ids"] = position_ids

        outputs = self.model.forward(**kwargs)
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
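        """Run one decoding step over the batch and return the generations produced
        at this step, plus the updated batch (or None if every request stopped)."""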
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

        logits, past = self.forward(
            batch.input_ids,
            attention_mask,
            batch.position_ids,
            batch.past_key_values,
        )

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids.view(1, -1), logits[-1:, :]
            )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids[:, 0], prefix_offset, read_offset
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :, 0]
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None

                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None

                # Prefill
                if stopping_criteria.current_tokens == 1:
                    # Remove generated token to only have prefill and add nan for first prompt token
                    prefill_logprobs = [float("nan")] + torch.log_softmax(
                        logits, -1
                    ).gather(1, all_input_ids[1:]).squeeze(1)[
                        -new_input_length:-1
                    ].tolist()
                    prefill_token_ids = all_input_ids[-new_input_length:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id_squeezed,
                    next_token_logprob,
                    next_token_text,
                    next_token_id_squeezed.item() in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_ids[i, 0] = next_token_id
            batch.all_input_ids[i] = all_input_ids
            batch.input_lengths[i] = new_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, new_input_length)

        # We finished all generations in the batch; there is no next batch
        if stopped:
            return generations, None

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]

        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Decrease right offset
        batch.padding_right_offset -= 1

        # Update position_ids
        batch.position_ids = batch.position_ids[:, -1:] + 1

        # Update past key values
        batch.past_key_values = past

        return generations, batch