# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .pytorch_utils import torch_int_div
from .utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each
            tensor of shape `(batch_size, config.vocab_size)`).
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
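    # A minimal usage sketch for this output class (illustrative; assumes a local causal LM
    # checkpoint such as "gpt2" and the public `generate` API):
    #
    #     from transformers import AutoModelForCausalLM, AutoTokenizer
    #
    #     tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #     model = AutoModelForCausalLM.from_pretrained("gpt2")
    #     prompt = tokenizer("Hello there", return_tensors="pt")
    #     out = model.generate(**prompt, do_sample=False, return_dict_in_generate=True, output_scores=True)
    #     # `out` is a GreedySearchDecoderOnlyOutput: `out.sequences` holds the generated token ids and
    #     # `out.scores` is a tuple with one `(batch_size, vocab_size)` tensor per generated token.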


@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size, config.vocab_size)`).
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each
            tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`).
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_return_sequences, config.vocab_size)`).
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape
            `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of the log
            probabilities of tokens conditioned on the previously generated tokens in this beam.
            `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
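    # Illustrative sketch (with `model` and `prompt` as in the sketch for the greedy-search output
    # above): when `num_beams > 1` and `do_sample=False`, `generate` returns an instance of this
    # class if `return_dict_in_generate=True`:
    #
    #     out = model.generate(
    #         **prompt, num_beams=4, num_return_sequences=2, return_dict_in_generate=True, output_scores=True
    #     )
    #     # out.sequences_scores -> final beam score of each returned sequence
    #     # out.scores           -> per-step beam transition scores
    #     # out.beam_indices     -> beam index each generated token was taken from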


@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of the log
            probabilities of tokens conditioned on the previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sample.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of the log
            probabilities of tokens conditioned on the previously generated tokens in this beam.
            `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of the log
            probabilities of tokens conditioned on the previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]


class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
          if `constraints!=None` or `force_words_ids!=None`.
    """

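    # Dispatch sketch for the decoding strategies listed above (illustrative; `input_ids` is assumed
    # to come from a tokenizer and `word_id` is a placeholder token id):
    #
    #     model.generate(input_ids)                                                         # greedy_search
    #     model.generate(input_ids, do_sample=True, top_p=0.9)                              # sample
    #     model.generate(input_ids, num_beams=4)                                            # beam_search
    #     model.generate(input_ids, num_beams=4, do_sample=True)                            # beam_sample
    #     model.generate(input_ids, num_beams=4, num_beam_groups=2, diversity_penalty=0.5)  # group_beam_search
    #     model.generate(input_ids, num_beams=4, force_words_ids=[[word_id]])               # constrained_beam_search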
    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside "
                f"{input_name} which is not allowed. "
                f"Make sure to either pass {inputs} or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        # 3. models with `input_ids` can also make use of `inputs_embeds`
        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. Only encoder-decoder models can have non `input_ids` input format
        if not self.config.is_encoder_decoder and input_name != "input_ids":
            raise ValueError(
                f"If {input_name} is passed as a model-specific keyword "
                "input then the model has to be an encoder-decoder and not a "
                f"{self.__class__.__name__}."
            )

        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
        if inputs is None:
            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

        return inputs, input_name, model_kwargs

    def _can_retrieve_inputs_from_name(
        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        If `inputs` is None and `name` is both an argument of the forward function and present in the keyword
        arguments, then the inputs can be retrieved from `name`.
        """
        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
            inspect.signature(self.forward).parameters.keys()
        )

        if can_retrieve_inputs and inputs is not None:
            raise ValueError(f"Only one of {name} and {self.main_input_name} can be passed, not both.")

        return can_retrieve_inputs

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
        """
        return {"input_ids": input_ids}

    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
        """
        return logits

    def _prepare_input_ids_for_generation(
        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
    ) -> torch.LongTensor:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: int,
        eos_token_id: int,
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
            (eos_token_id is not None) and (pad_token_id != eos_token_id)
        )
        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
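    # Illustrative sketch of the rule above (toy tensors; assumes pad_token_id=0 and a different
    # eos_token_id): padded `input_ids` yield a 0/1 mask, any other input yields all ones.
    #
    #     input_ids = torch.tensor([[5, 6, 0, 0], [5, 6, 7, 8]])
    #     input_ids.ne(0).long()  # -> tensor([[1, 1, 0, 0], [1, 1, 1, 1]])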

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        device: torch.device = None,
    ) -> torch.LongTensor:

        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            if device is None:
                device = self.device
            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs
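    # Illustrative sketch of the row expansion above (toy tensor): with `expand_size=2` every batch
    # row is repeated so that all beams / return sequences start from the same prompt.
    #
    #     input_ids = torch.tensor([[1, 2], [3, 4]])
    #     idx = torch.arange(2).view(-1, 1).repeat(1, 2).view(-1)  # tensor([0, 0, 1, 1])
    #     input_ids.index_select(0, idx)  # -> tensor([[1, 2], [1, 2], [3, 4], [3, 4]])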

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs.past_key_values
        elif "mems" in outputs:
            model_kwargs["past"] = outputs.mems
        elif "past_buckets_states" in outputs:
            model_kwargs["past"] = outputs.past_buckets_states
        else:
            model_kwargs["past"] = None

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return model_kwargs

    def _reorder_cache(self, past, beam_idx):
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )

    def _get_logits_warper(
        self,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        renormalize_logits: Optional[bool] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] object that contains all relevant [`LogitsWarper`] instances
        used for multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        typical_p = typical_p if typical_p is not None else self.config.typical_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = LogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers
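    # Illustrative sketch of how the returned list is used inside the sampling loops (assumes
    # `next_token_logits` of shape `(batch_size, vocab_size)`):
    #
    #     warpers = self._get_logits_warper(top_k=50, top_p=0.95, temperature=0.7, num_beams=1)
    #     scores = warpers(input_ids, next_token_logits)
    #     probs = nn.functional.softmax(scores, dim=-1)
    #     next_tokens = torch.multinomial(probs, num_samples=1)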

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
        exponential_decay_length_penalty: Tuple,
        logits_processor: Optional[LogitsProcessorList],
        renormalize_logits: Optional[bool],
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] object that contains all relevant [`LogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        processors = LogitsProcessorList()

        # init warp parameters
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        encoder_no_repeat_ngram_size = (
            encoder_no_repeat_ngram_size
            if encoder_no_repeat_ngram_size is not None
            else self.config.encoder_no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
        exponential_decay_length_penalty = (
            exponential_decay_length_penalty
            if exponential_decay_length_penalty is not None
            else self.config.exponential_decay_length_penalty
        )
        # instantiate processors list

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if diversity_penalty is not None and diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
                )
            )
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
            if self.config.is_encoder_decoder:
                processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
            else:
                raise ValueError(
                    "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
                )
        if bad_words_ids is not None:
            processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
        if min_length is not None and eos_token_id is not None and min_length > 0:
            processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
        if prefix_allowed_tokens_fn is not None:
            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
        if forced_bos_token_id is not None:
            processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        if remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        if exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
            )
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            processors.append(LogitNormalization())
        return processors

    def _get_stopping_criteria(
        self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
    ) -> StoppingCriteriaList:
        criteria = StoppingCriteriaList()
        if max_length is not None:
            criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            criteria.append(MaxTimeCriteria(max_time=max_time))
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        return criteria
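    # Illustrative sketch: the returned `StoppingCriteriaList` is called once per decoding step and
    # returns True once `max_length` tokens exist or `max_time` seconds have elapsed.
    #
    #     criteria = self._get_stopping_criteria(max_length=20, max_time=None, stopping_criteria=StoppingCriteriaList())
    #     if criteria(input_ids, scores):
    #         ...  # stop generating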

    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def compute_transition_beam_scores(
        self,
        sequences: torch.Tensor,
        scores: Tuple[torch.Tensor],
        beam_indices: torch.Tensor,
        eos_token_id: int = None,
    ):
        """compute the transition probabilities of sequences given generation
        scores and beam indices"""

        # reshape scores as [batch_size * num_beams * vocab_size, # generation steps]
        # with # generation steps being seq_len - input_length
        scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

        # start of generated tokens
        cut_idx = sequences.shape[-1] - scores.shape[-1]
        # adjust for beam indices
        beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
        # compute real indices
        indices = sequences[:, cut_idx:] + beam_sequence_indices
        # gather scores and run
        transition_scores = scores.gather(0, indices)
        # if an EOS token was generated before the full sequence length `sequences.shape[-1]`, find its first
        # occurrence so that the scores of all later positions can be masked out
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is not None:
            is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
            # make sure first eos token still contributes to transition probs
            is_eos_token_id[:, -1] = False
            is_eos_token_id = is_eos_token_id.roll(1, -1)
            # all indices after eos should be masked
            zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
            # zero out padded probs
            transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)

        return transition_scores
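    # Illustrative usage sketch (assumes a beam-search `generate` call made with
    # `return_dict_in_generate=True` and `output_scores=True`, `out` being the returned
    # beam-search output object so that `out.scores` and `out.beam_indices` are populated):
    #
    #     transition_scores = model.compute_transition_beam_scores(
    #         sequences=out.sequences, scores=out.scores, beam_indices=out.beam_indices
    #     )
    #     # one log-probability score per generated token of each returned sequence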

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        renormalize_logits: Optional[bool] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
881
882
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
883
        remove_invalid_values: Optional[bool] = None,
884
        synced_gpus: Optional[bool] = False,
885
        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
886
887
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
Sylvain Gugger's avatar
Sylvain Gugger committed
888
        r"""

        Generates sequences of token ids for models with a language modeling head. The method supports the following
        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

            - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
              `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
              `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
              `do_sample=False`.
            - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
              `num_beams>1` and `do_sample=True`.
            - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
              `num_beams>1` and `num_beam_groups>1`.
            - *constrained beam-search decoding* by calling
              [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
              `force_words_ids!=None`.

        <Tip warning={true}>

        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
        defined in the model's config (`config.json`) which in turn defaults to the
        [`~modeling_utils.PretrainedConfig`] of the model.

        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            max_length (`int`, *optional*, defaults to `model.config.max_length`):
                The maximum length of the sequence to be generated.
            max_new_tokens (`int`, *optional*, defaults to None):
                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. Use either
                `max_new_tokens` or `max_length` but not both, as they serve the same purpose.
            min_length (`int`, *optional*, defaults to 10):
                The minimum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `False`):
                Whether or not to use sampling; use greedy decoding otherwise.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            typical_p (`float`, *optional*, defaults to 1.0):
                The amount of probability mass from the original distribution to be considered in typical decoding. If
                set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
            repetition_penalty (`float`, *optional*, defaults to 1.0):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to 1.0):
                 Exponential penalty to the length, used with beam-based generation. The beam score (the sequence
                 log-likelihood, i.e. a negative number) is divided by the sequence length raised to this power, so
                 values > 0.0 encourage the model to generate longer sequences, values < 0.0 encourage shorter
                 sequences, and 0.0 means no length penalty.
            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids(`List[List[int]]`, *optional*):
                List of token ids that are not allowed to be generated. In order to get the token ids of the words that
                should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
                add_special_tokens=False).input_ids`.
            force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
                List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
                list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
                this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
                where one can allow different forms of each word.
            num_return_sequences(`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch.
            max_time(`float`, *optional*, defaults to None):
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after the allocated time has been passed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
                that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
                as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            use_cache: (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            num_beam_groups (`int`, *optional*, defaults to 1):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to 0.0):
                This value is subtracted from a beam's score if it generates a token same as any beam from other group
                at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                 Custom logits processors that complement the default logits processors built from arguments and a
                 model's config. If a logit processor is passed that is already created with the arguments or a model's
                 config an error is thrown. This feature is intended for advanced users.
            renormalize_logits: (`bool`, *optional*, defaults to `False`):
                Whether to renormalize the logits after applying all the logits processors or warpers (including the
                custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
                score logits are normalized but some logit processors or warpers break the normalization.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 model's config. If a stopping criteria is passed that is already created with the arguments or a
                 model's config an error is thrown. This feature is intended for advanced users.
            constraints (`List[Constraint]`, *optional*):
                 Custom constraints that can be added to the generation to ensure that the output will contain the use
                 of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            forced_bos_token_id (`int`, *optional*):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
                crashing. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
                This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
                generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
                where the penalty starts and `decay_factor` represents the factor of exponential decay.

            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
                    - [`~generation_utils.SampleDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
                    - [`~generation_utils.SampleEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSampleEncoderDecoderOutput`]

        Examples:

        Greedy Decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # generate up to 30 tokens
        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
        ```

        Multinomial Sampling:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # sample up to 30 tokens
        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
        ```

        Beam-search decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> outputs = model.generate(input_ids)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
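
        >>> # a minimal sketch of constrained beam-search decoding with the same model: force a word to
        >>> # appear in the output via `force_words_ids` (requires `num_beams` > 1); the exact translation
        >>> # produced by this call may differ
        >>> force_words_ids = tokenizer(["Europa"], add_special_tokens=False).input_ids
        >>> outputs = model.generate(input_ids, force_words_ids=force_words_ids, num_beams=5)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)  # doctest: +SKIP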
        ```"""
        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is None and hasattr(self.config, "decoder"):
            eos_token_id = self.config.decoder.eos_token_id

        if pad_token_id is None and eos_token_id is not None:
            # special case if pad_token_id is not defined
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # 2. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
        batch_size = inputs_tensor.shape[0]

        # 3. Define other model kwargs
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["use_cache"] = use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
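            # note: the helper below builds the default mask as 1 for real tokens and 0 for positions equal to
            # `pad_token_id` in the prompt, and falls back to an all-ones mask when no padding is detected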
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, pad_token_id, eos_token_id
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )
        else:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

        input_ids_seq_length = input_ids.shape[-1]

        # 5. Prepare `max_length` depending on other stopping criteria
        # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
        if max_length is None and max_new_tokens is not None:
            max_length = max_new_tokens + input_ids_seq_length
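            # e.g. a prompt of 5 tokens combined with `max_new_tokens=20` results in `max_length = 25`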
        elif max_length is not None and max_new_tokens is not None:
            # Both are set, this is odd, raise a warning
            warnings.warn(
                "Both `max_length` and `max_new_tokens` have been set "
                f"but they serve the same purpose. `max_length` {max_length} "
                f"will take priority over `max_new_tokens` {max_new_tokens}.",
                UserWarning,
            )
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length

        if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Infeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
                f"length ({max_length})"
            )
        if input_ids_seq_length >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to"
                f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
                " ``config.max_length`` or ``max_length``."
            )

        # 6. determine generation mode
        is_constraint_gen_mode = constraints is not None or force_words_ids is not None
        is_greedy_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_sample_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_beam_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_beam_sample_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # 7. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
            renormalize_logits=renormalize_logits,
        )

        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
        )

        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            if num_beams <= 1:
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

            if do_sample:
                raise ValueError("`do_sample` needs to be `False` for constrained generation.")

            if num_beam_groups is not None and num_beam_groups > 1:
                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

            final_constraints = []
            if constraints is not None:
                final_constraints = constraints

            if force_words_ids is not None:

                def typeerror():
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                        f"of positive integers, but is {force_words_ids}."
                    )

                if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
                    typeerror()

                for word_ids in force_words_ids:
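                    # a nested list of token-id lists describes several acceptable forms of the same word
                    # (handled as a `DisjunctiveConstraint`); a flat list of ints is a single forced phrase
                    # (handled as a `PhrasalConstraint`)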
                    if isinstance(word_ids[0], list):
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 10. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
            or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "It might be possible to"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> outputs = model.greedy_search(
        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
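        # entries are 1 while a sequence is still being generated and are set to 0 once `eos_token_id` is produced;
        # finished rows then keep receiving `pad_token_id` in the loop below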
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
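                # for rows already marked as finished (unfinished_sequences == 0) this selects `pad_token_id`,
                # otherwise it keeps the greedily chosen token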
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
Sylvain Gugger's avatar
Sylvain Gugger committed
1818
            [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
Stas Bekman's avatar
Stas Bekman committed
1819
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
Sylvain Gugger's avatar
Sylvain Gugger committed
1820
1821
            [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
1822
1823
1824
1825
1826
1827
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
1828
1829
1830
1831
1832
1833
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
1834
1835
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
1836
        ... )
1837
        >>> import torch
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
Sylvain Gugger's avatar
Sylvain Gugger committed
1849
1850
1851
1852
1853
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
1854
        >>> # instantiate logits processors
Sylvain Gugger's avatar
Sylvain Gugger committed
1855
1856
1857
1858
1859
1860
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )
1861

1862
1863
1864
1865
1866
1867
1868
1869
1870
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.sample(
        ...     input_ids,
        ...     logits_processor=logits_processor,
        ...     logits_warper=logits_warper,
        ...     stopping_criteria=stopping_criteria,
        ... )
1871

1872
1873
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
1874
        ```"""
1875
1876
1877

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
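        # 1 marks a sequence that is still being generated, 0 a sequence that has already produced `eos_token_id`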
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
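            # convert the processed/warped scores to probabilities and draw one token id per sequence in the batch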
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
2034
2035
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
2036

2037
        Parameters:
2038

2039
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
2040
                The sequence used as a prompt for the generation.
2041
            beam_scorer (`BeamScorer`):
Sylvain Gugger's avatar
Sylvain Gugger committed
2042
2043
                An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
2044
            logits_processor (`LogitsProcessorList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2045
2046
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
2047
            stopping_criteria (`StoppingCriteriaList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2048
2049
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
2050
            max_length (`int`, *optional*, defaults to 20):
Sylvain Gugger's avatar
Sylvain Gugger committed
2051
2052
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
2053
2054
2055
2056
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
2057
            output_attentions (`bool`, *optional*, defaults to `False`):
2058
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
2059
                returned tensors for more details.
2060
            output_hidden_states (`bool`, *optional*, defaults to `False`):
2061
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
2062
                for more details.
2063
            output_scores (`bool`, *optional*, defaults to `False`):
2064
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
2065
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
2066
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2067
            synced_gpus (`bool`, *optional*, defaults to `False`):
2068
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
2069
            model_kwargs:
Sylvain Gugger's avatar
Sylvain Gugger committed
2070
2071
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.
2072

2073
        Return:
Sylvain Gugger's avatar
Sylvain Gugger committed
2074
            [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
Stas Bekman's avatar
Stas Bekman committed
2075
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
Sylvain Gugger's avatar
Sylvain Gugger committed
2076
2077
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
2078
2079
2080
2081
2082
2083
2084
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
2085
2086
2087
2088
2089
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     BeamSearchScorer,
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
Sylvain Gugger's avatar
Sylvain Gugger committed
2108
2109
2110
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
Sylvain Gugger's avatar
Sylvain Gugger committed
2121
2122
2123
2124
2125
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
2126
2127
2128

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

2129
2130
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
2131
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You have not defined any stopping_criteria; this will likely loop forever.", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
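        # only the first beam of each batch item starts at score 0; the remaining beams get -1e9 so that the first
        # decoding step does not select num_beams copies of the same continuation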
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
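            # adding the running `beam_scores` turns the per-step log-probabilities into cumulative hypothesis
            # scores, which is what the beam scorer ranks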

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
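            # keep the 2 * num_beams best candidates per batch item so that, even if up to num_beams of them end in
            # `eos_token_id`, there are still enough live continuations to refill every beam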

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch_int_div(next_tokens, vocab_size)
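            # the top-k above was taken over the flattened (num_beams * vocab_size) scores: the integer division
            # recovered which beam each candidate came from, and the modulo below recovers its token id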
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
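                # append the newly selected beam index to the reordering history of every surviving hypothesis so
                # that per-step scores can later be traced back through the beam reshuffles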

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            else:
                num_return_sequences = beam_scorer.num_beam_hyps_to_keep
                # return only as many indices as sequences
                beam_indices = tuple(
                    (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
                )
                beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSampleOutput, torch.LongTensor]:
        r"""
2348
2349
        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
2350

2351
        Parameters:
2352

2353
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
2354
                The sequence used as a prompt for the generation.
2355
            beam_scorer (`BeamScorer`):
Sylvain Gugger's avatar
Sylvain Gugger committed
2356
2357
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
2358
            logits_processor (`LogitsProcessorList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2359
2360
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
2361
            stopping_criteria (`StoppingCriteriaList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2362
2363
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
2364
            logits_warper (`LogitsProcessorList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2365
2366
2367
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
2368
            max_length (`int`, *optional*, defaults to 20):
Sylvain Gugger's avatar
Sylvain Gugger committed
2369
2370
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
2371
2372
2373
2374
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
2375
            output_attentions (`bool`, *optional*, defaults to `False`):
2376
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
2377
                returned tensors for more details.
2378
            output_hidden_states (`bool`, *optional*, defaults to `False`):
2379
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
2380
                for more details.
2381
            output_scores (`bool`, *optional*, defaults to `False`):
2382
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
2383
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
2384
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2385
            synced_gpus (`bool`, *optional*, defaults to `False`):
2386
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
2387
            model_kwargs:
Sylvain Gugger's avatar
Sylvain Gugger committed
2388
2389
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.
2390

2391
        Return:
Sylvain Gugger's avatar
Sylvain Gugger committed
2392
            [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
Stas Bekman's avatar
Stas Bekman committed
2393
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
Sylvain Gugger's avatar
Sylvain Gugger committed
2394
2395
            [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
Sylvain Gugger's avatar
Sylvain Gugger committed
2426
2427
2428
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
Sylvain Gugger's avatar
Sylvain Gugger committed
2440
2441
2442
        >>> logits_processor = LogitsProcessorList(
        ...     [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
        ... )
2443
        >>> # instantiate logits processors
Sylvain Gugger's avatar
Sylvain Gugger committed
2444
2445
2446
2447
2448
2449
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )
2450
2451
2452
2453
2454

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

2455
2456
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
2457
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (logits_warper(input_ids, next_token_scores_processed),)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            probs = nn.functional.softmax(next_token_scores, dim=-1)
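            # draw 2 * num_beams candidate tokens per batch item (rather than only num_beams) so that enough
            # non-finished candidates remain to refill every beam after hypotheses ending in eos are set aside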

            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
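            # the sampled candidates are re-sorted by score below so that the beam scorer receives them in
            # descending order, just as in plain beam search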

            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            else:
                num_return_sequences = beam_scorer.num_beam_hyps_to_keep
                # return only as many indices as sequences
                beam_indices = tuple(
                    (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
                )
                beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSampleEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSampleDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ):
        r"""
2670
2671
        Generates sequences of token ids for models with a language modeling head using **diverse beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
2672
2673
2674

        Parameters:

2675
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
2676
                The sequence used as a prompt for the generation.
2677
            beam_scorer (`BeamScorer`):
Sylvain Gugger's avatar
Sylvain Gugger committed
2678
2679
                An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
2680
            logits_processor (`LogitsProcessorList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2681
2682
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
2683
            stopping_criteria (`StoppingCriteriaList`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2684
2685
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
2686
            max_length (`int`, *optional*, defaults to 20):
Sylvain Gugger's avatar
Sylvain Gugger committed
2687
2688
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
2689
2690
2691
2692
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
2693
            output_attentions (`bool`, *optional*, defaults to `False`):
2694
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
2695
                returned tensors for more details.
2696
            output_hidden_states (`bool`, *optional*, defaults to `False`):
2697
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
2698
                for more details.
2699
            output_scores (`bool`, *optional*, defaults to `False`):
2700
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
2701
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
2702
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2703
            synced_gpus (`bool`, *optional*, defaults to `False`):
2704
2705
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

2706
            model_kwargs:
2707
2708
                Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
                model is an encoder-decoder model the kwargs should include `encoder_outputs`.
2709
2710

        Return:
Sylvain Gugger's avatar
Sylvain Gugger committed
2711
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
Stas Bekman's avatar
Stas Bekman committed
2712
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
Sylvain Gugger's avatar
Sylvain Gugger committed
2713
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
2714
            `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
Sylvain Gugger's avatar
Sylvain Gugger committed
2715
            [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.
2716
2717
2718
2719
2720

        Examples:

        ```python
        >>> from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
2721
2722
2723
2724
2725
2726
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     HammingDiversityLogitsProcessor,
        ...     BeamSearchScorer,
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run diverse beam search using 6 beams
        >>> num_beams = 6
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
Sylvain Gugger's avatar
Sylvain Gugger committed
2745
2746
2747
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
2748
2749
2750
2751
2752
2753
2754
2755
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
Sylvain Gugger's avatar
Sylvain Gugger committed
2756
        ...     num_beam_groups=3,
2757
2758
2759
        ... )

        >>> # instantiate logits processors
Sylvain Gugger's avatar
Sylvain Gugger committed
2760
2761
2762
2763
2764
2765
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
2766

Sylvain Gugger's avatar
Sylvain Gugger committed
2767
2768
2769
        >>> outputs = model.group_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )
2770

2771
2772
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
2773
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
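        # each of the `num_beam_groups` groups runs as an independent beam search over `num_sub_beams` beams;
        # diversity across groups is typically encouraged by a `HammingDiversityLogitsProcessor` in `logits_processor`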
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
        else:
            beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))
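        # e.g. with num_beams=6 and num_beam_groups=3 (num_sub_beams=2), each batch entry starts as
        # [0, -1e9, 0, -1e9, 0, -1e9], so only the first beam of every group contributes at the first step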

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if output_scores:
                processed_score = torch.zeros_like(outputs.logits[:, -1, :])

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
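                # e.g. with batch_size=2 and num_beams=6, the group spanning beams [2, 4) selects rows
                # [2, 3, 8, 9] of the flattened (batch_size * num_beams) dimension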
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)

                next_token_scores = nn.functional.log_softmax(
                    next_token_logits, dim=-1
                )  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores_processed

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch_int_div(next_tokens, vocab_size)
                next_tokens = next_tokens % vocab_size
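                # next_tokens were picked over the flattened (group_size * vocab_size) axis, so the integer division
                # above recovers which beam of the group each candidate came from and the modulo recovers the token id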

                # stateless
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                if return_dict_in_generate and output_scores:
                    beam_indices[beam_group_idx] = tuple(
                        beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                    )

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
                )
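                # e.g. with num_beams=6, group_size=2 and group_start_idx=2, a group-local beam_idx of 3 belongs to
                # batch entry 1 with offset 1, so it maps to row 6 * 1 + 2 + 1 = 9 of the full beam layout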

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
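            # current_tokens was filled group by group above, so this appends one new token to every beam at once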

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            else:
                beam_indices = sum(beam_indices, ())
                num_return_sequences = beam_scorer.num_beam_hyps_to_keep
                # return only as many indices as sequences
                beam_indices = tuple(
                    (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
                )
                beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def constrained_beam_search(
        self,
        input_ids: torch.LongTensor,
        constrained_beam_scorer: ConstrainedBeamSearchScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        r"""
        Generates sequences of token ids for models with a language modeling head using **constrained beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation, while satisfying a list of positive constraints. For more information, the
                documentation of [`ConstrainedBeamSearchScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     ConstrainedBeamSearchScorer,
        ...     PhrasalConstraint,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run constrained beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> constraint_str = "Sie"
        >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
        >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]


        >>> # instantiate beam scorer
        >>> beam_scorer = ConstrainedBeamSearchScorer(
        ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.constrained_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt sind Sie?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(constrained_beam_scorer._beam_hyps)
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)

            scores_for_all_vocab = next_token_scores_processed.clone()
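            # keep the processed scores over the full vocabulary: the constrained beam scorer can then score
            # constraint tokens even when they do not appear among the top-k candidates selected below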

            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = constrained_beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                scores_for_all_vocab,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
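            # input_ids now follows the beams selected by the scorer, with the chosen next token appended to each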
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1

            if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = constrained_beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits)

    return logits
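

# A minimal usage sketch for `top_k_top_p_filtering` (illustrative only; the shapes, hyperparameters and sampling
# step below are arbitrary examples, not part of this module):
#
#     next_token_logits = torch.randn(1, 32000)  # (batch size, vocabulary size)
#     filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.9)
#     probs = torch.softmax(filtered_logits, dim=-1)  # filtered positions receive probability 0
#     next_token = torch.multinomial(probs, num_samples=1)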