# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .pytorch_utils import torch_int_div
from .utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape
            `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sample.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
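# Illustrative sketch (not part of the original module): when `return_dict_in_generate=True` is passed,
# `generate` returns one of the dataclasses above instead of a plain tensor. The model name and prompt
# below are assumptions made for the example only.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
#     out = model.generate(input_ids, return_dict_in_generate=True, output_scores=True)
#     isinstance(out, GreedySearchDecoderOnlyOutput)  # True for the default greedy setup
#     out.sequences.shape  # (batch_size, sequence_length)
#     len(out.scores)  # one tuple entry per generated token, each of shape (batch_size, vocab_size)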


class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
          if `constraints!=None` or `force_words_ids!=None`.
    """
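    # Illustrative sketch (not part of the original module): which decoding sub-method `generate` dispatches
    # to, assuming `model` is a `PreTrainedModel` with a language modeling head and `input_ids` is a prompt:
    #
    #     model.generate(input_ids)                                      # greedy_search (num_beams=1, do_sample=False)
    #     model.generate(input_ids, do_sample=True, top_p=0.9)           # sample
    #     model.generate(input_ids, num_beams=4)                         # beam_search
    #     model.generate(input_ids, num_beams=4, do_sample=True)         # beam_sample
    #     model.generate(input_ids, num_beams=4, num_beam_groups=2)      # group_beam_search
    #     model.generate(input_ids, num_beams=2, force_words_ids=[[42]]) # constrained_beam_search (42 is a placeholder id)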

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside "
                f"{input_name} which is not allowed. "
                f"Make sure to either pass {inputs} or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        # 3. models with `input_ids` can also make use of `inputs_embeds`
        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. Only encoder-decoder models can have non `input_ids` input format
        if not self.config.is_encoder_decoder and input_name != "input_ids":
            raise ValueError(
                f"If {input_name} is passed as model-specific keyword "
                "input then model has to be an encoder-decoder and not a "
                f"{self.__class__.__name__}."
            )

        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
        if inputs is None:
            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

        return inputs, input_name, model_kwargs

    def _can_retrieve_inputs_from_name(
        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        If `inputs` is None and `name` is in both the forward function's signature and the keyword arguments, then
        `inputs` can be retrieved from `name`.
        """
        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
            inspect.signature(self.forward).parameters.keys()
        )

        if can_retrieve_inputs and inputs is not None:
            raise ValueError(f"Can only pass one of {name} and {self.main_input_name}")

        return can_retrieve_inputs

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
        """
        return {"input_ids": input_ids}

    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
        """
        return logits

    def _prepare_input_ids_for_generation(
        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
    ) -> torch.LongTensor:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[int],
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)

        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        device: torch.device = None,
    ) -> torch.LongTensor:

        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            if device is None:
                device = self.device
            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
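        # Fallback order: `decoder_start_token_id` (argument or config), the decoder sub-config's
        # `decoder_start_token_id`, `bos_token_id` (argument or config), the decoder sub-config's
        # `bos_token_id`; otherwise an error is raised.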
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
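        # Build an index that repeats every batch row `expand_size` times (e.g. for `num_beams` or
        # `num_return_sequences`) and gather `input_ids` and the matching model kwargs along dim 0.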
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs.past_key_values
        elif "mems" in outputs:
            model_kwargs["past"] = outputs.mems
        elif "past_buckets_states" in outputs:
            model_kwargs["past"] = outputs.past_buckets_states
        else:
            model_kwargs["past"] = None

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return model_kwargs
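    # Illustrative sketch (not part of the original module): the decoding loops call the helper above after
    # every forward pass to carry the cache and masks to the next step, roughly:
    #
    #     outputs = self(**model_inputs, return_dict=True)
    #     model_kwargs = self._update_model_kwargs_for_generation(
    #         outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    #     )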

    def _reorder_cache(self, past, beam_idx):
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )

    def _get_logits_warper(
        self,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        renormalize_logits: Optional[bool] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] object with all relevant [`LogitsWarper`] instances used for
        multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        typical_p = typical_p if typical_p is not None else self.config.typical_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = LogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers
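    # Illustrative sketch (not part of the original module): during sampling the returned warpers are applied
    # to the next-token logits before drawing a token, roughly as done in `sample`:
    #
    #     next_token_scores = logits_warper(input_ids, next_token_logits)
    #     probs = nn.functional.softmax(next_token_scores, dim=-1)
    #     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)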

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
        exponential_decay_length_penalty: Tuple,
        logits_processor: Optional[LogitsProcessorList],
        renormalize_logits: Optional[bool],
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] object with all relevant [`LogitsProcessor`] instances used to
        modify the scores of the language model head.
        """
        processors = LogitsProcessorList()

        # init warp parameters
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        encoder_no_repeat_ngram_size = (
            encoder_no_repeat_ngram_size
            if encoder_no_repeat_ngram_size is not None
            else self.config.encoder_no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
        exponential_decay_length_penalty = (
            exponential_decay_length_penalty
            if exponential_decay_length_penalty is not None
            else self.config.exponential_decay_length_penalty
        )
        # instantiate processors list

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if diversity_penalty is not None and diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
                )
            )
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
            if self.config.is_encoder_decoder:
                processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
            else:
                raise ValueError(
                    "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
                )
        if bad_words_ids is not None:
            processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
        if min_length is not None and eos_token_id is not None and min_length > 0:
            processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
        if prefix_allowed_tokens_fn is not None:
            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
        if forced_bos_token_id is not None:
            processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        if remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        if exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
            )
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            processors.append(LogitNormalization())
        return processors
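    # Illustrative sketch (not part of the original module): during greedy decoding the returned processors
    # are applied in order to the raw next-token logits, roughly as done in `greedy_search`:
    #
    #     next_token_scores = logits_processor(input_ids, next_token_logits)
    #     next_tokens = torch.argmax(next_token_scores, dim=-1)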

    def _get_stopping_criteria(
        self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
    ) -> StoppingCriteriaList:
        criteria = StoppingCriteriaList()
        if max_length is not None:
            criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            criteria.append(MaxTimeCriteria(max_time=max_time))
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        return criteria

    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
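        # A user-supplied processor/criterion must not duplicate one that `generate` already built from its
        # arguments or the model config; duplicates raise below, everything else is appended to the defaults.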
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def compute_transition_beam_scores(
        self,
        sequences: torch.Tensor,
        scores: Tuple[torch.Tensor],
        beam_indices: torch.Tensor,
        eos_token_id: int = None,
    ):
        """Compute the transition scores (log probabilities) of sequences given the generation scores and
        beam indices."""

        # 1. reshape scores as [batch_size * num_beams * vocab_size, # generation steps]
        # with # generation steps being seq_len - input_length
        scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

        # 2. cut beam_indices to longest beam length
        beam_indices_mask = beam_indices < 0
        max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
        beam_indices = beam_indices[:, :max_beam_length]
        beam_indices_mask = beam_indices_mask[:, :max_beam_length]

        # 3. Set indices of beams that finished early to 0
        # such indices will be masked correctly afterwards
        beam_indices[beam_indices_mask] = 0

        # 4. multiply beam_indices with vocab size to gather correctly from scores
        beam_sequence_indices = beam_indices * self.config.vocab_size

        # 5. Define which indices contributed to scores
        cut_idx = sequences.shape[-1] - max_beam_length
        indices = sequences[:, cut_idx:] + beam_sequence_indices

        # 6. Compute scores
        transition_scores = scores.gather(0, indices)

        # 7. Mask out transition_scores of beams that stopped early
        transition_scores[beam_indices_mask] = 0

        return transition_scores
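    # Illustrative sketch (not part of the original module): recovering per-step transition scores for the
    # beams returned by a beam search `generate` call, assuming `out` was produced with
    # `return_dict_in_generate=True` and `output_scores=True`:
    #
    #     transition_scores = model.compute_transition_beam_scores(
    #         out.sequences, out.scores, out.beam_indices
    #     )
    #     # summing over the generation steps recovers the final beam scores up to length normalization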

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # Excludes arguments that are handled before calling any model function
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)

        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        renormalize_logits: Optional[bool] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head. The method supports the following
        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

            - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
              `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
              `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
              `do_sample=False`.
            - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
              `num_beams>1` and `do_sample=True`.
            - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
              `num_beams>1` and `num_beam_groups>1`.
            - *constrained beam-search decoding* by calling
              [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
              `force_words_ids!=None`.

        <Tip warning={true}>

        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
        defined in the model's config (`config.json`) which in turn defaults to the
        [`~modeling_utils.PretrainedConfig`] of the model.

        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            max_length (`int`, *optional*, defaults to `model.config.max_length`):
                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
                the prompt.
            max_new_tokens (`int`, *optional*):
                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
            min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value):
                The minimum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value):
                Whether or not to use sampling; use greedy decoding otherwise.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
            num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value):
                Number of beams for beam search. 1 means no beam search.
            temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
                The amount of probability mass from the original distribution to be considered in typical decoding. If
                set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
            repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value):
                Exponential penalty to the length that is used with beam-based generation. It is applied as an
                exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the
                score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer
                sequences, while `length_penalty` < 0.0 encourages shorter sequences.
            no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids (`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`):
                List of token ids that are not allowed to be generated. In order to get the token ids of the words that
                should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
                add_special_tokens=False).input_ids`.
            force_words_ids (`List[List[int]]` or `List[List[List[int]]]`, *optional*):
                List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
                list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
                this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
                where one can allow different forms of each word.
            num_return_sequences (`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value):
                The number of independently computed returned sequences for each element in the batch.
            max_time (`float`, *optional*):
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after the allocated time has passed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
                that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
                as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            use_cache (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value):
                This value is subtracted from a beam's score if it generates the same token as any beam from another
                group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config, an error is thrown. This feature is intended for advanced users.
            renormalize_logits (`bool`, *optional*, defaults to `False`):
                Whether to renormalize the logits after applying all the logits processors or warpers (including the
                custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
                score logits are normalized but some logit processors or warpers break the normalization.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criterion is passed that is already created with the arguments or a
                model's config, an error is thrown. This feature is intended for advanced users.
            constraints (`List[Constraint]`, *optional*):
                Custom constraints that can be added to the generation to ensure that the output will contain the use
                of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
            output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
            forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
                crashing. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
                This tuple adds an exponentially increasing length penalty after a certain number of tokens have been
                generated. The tuple shall consist of `(start_index, decay_factor)`, where `start_index` indicates
                where the penalty starts and `decay_factor` represents the factor of exponential decay.

            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
                    - [`~generation_utils.SampleDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
                    - [`~generation_utils.SampleEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSampleEncoderDecoderOutput`]

        Examples:

        Greedy Decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # generate up to 30 tokens
        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
        ```

        Multinomial Sampling:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # sample up to 30 tokens
        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
        ```

        Beam-search decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> outputs = model.generate(input_ids, num_beams=5)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
        ```"""
        # 0. Validate model kwargs
        self._validate_model_kwargs(model_kwargs.copy())

        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is None and hasattr(self.config, "decoder"):
            eos_token_id = self.config.decoder.eos_token_id

        if pad_token_id is None and eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # 2. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
        batch_size = inputs_tensor.shape[0]

        # 3. Define other model kwargs
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["use_cache"] = use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

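        # If no `attention_mask` was passed, infer a default one from the inputs -- but only when the model's
        # `forward` accepts an attention mask and we are not reusing precomputed `encoder_outputs`.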
        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, pad_token_id, eos_token_id
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )
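            # the encoder is only run once here; its outputs are stored in `model_kwargs["encoder_outputs"]`
            # and reused at every decoding step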

        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )
        else:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

        # 5. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        if max_length is None and max_new_tokens is None:
            warnings.warn(
                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
                "using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif max_length is None and max_new_tokens is not None:
            max_length = max_new_tokens + input_ids_seq_length
        elif max_length is not None and max_new_tokens is not None:
            raise ValueError(
                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                " limit to the generated output length. Remove one of those arguments. Please refer to the"
                " documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
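        # (e.g. with the `max_new_tokens` branch above, a 5-token prompt and `max_new_tokens=20` resolve to `max_length=25`)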
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length

        if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
                f"length ({max_length})"
            )
        if input_ids_seq_length >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
                "`max_new_tokens`."
            )

        # 6. determine generation mode
        is_constraint_gen_mode = constraints is not None or force_words_ids is not None
        is_greedy_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_sample_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_beam_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_beam_sample_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
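        # each mode maps to one decoding method below: greedy_search, sample, beam_search, beam_sample,
        # group_beam_search, or constrained_beam_search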

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # 7. prepare distribution pre_processing samplers
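        # (this merges the processors implied by the arguments above with any user-provided `logits_processor`)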
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
            renormalize_logits=renormalize_logits,
        )

        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
        )

        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            if num_beams <= 1:
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

            if do_sample:
                raise ValueError("`do_sample` needs to be false for constrained generation.")

            if num_beam_groups is not None and num_beam_groups > 1:
                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

            final_constraints = []
            if constraints is not None:
                final_constraints = constraints

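            # e.g. `force_words_ids=[[42, 67]]` yields a single `PhrasalConstraint`, while a nested
            # `force_words_ids=[[[42], [67, 68]]]` yields a `DisjunctiveConstraint` (token ids here are illustrative)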
            if force_words_ids is not None:

                def typeerror():
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                        f"of positive integers, but is {force_words_ids}."
                    )

                if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
                    typeerror()

                for word_ids in force_words_ids:
                    if isinstance(word_ids[0], list):
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 10. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.

            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
            or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "It might be possible to"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> outputs = model.greedy_search(
        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
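        # 1 = still generating, 0 = finished (used below to force `pad_token_id` once a sequence hits EOS)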
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)
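            # (e.g. a `MinLengthLogitsProcessor` sets the EOS score to -inf until the minimum length is reached)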

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
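            # (this e.g. carries over the cache / `past` and extends `attention_mask` by one position when present)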
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.sample(
        ...     input_ids,
        ...     logits_processor=logits_processor,
        ...     logits_warper=logits_warper,
        ...     stopping_criteria=stopping_criteria,
        ... )
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
        ```"""

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )
        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]
        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
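            # processors enforce constraints while warpers reshape the distribution, e.g. TopKLogitsWarper
            # masks everything outside the k most likely tokens and TemperatureLogitsWarper divides the
            # scores by the temperature before the softmax below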

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
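            # multinomial sampling draws one token id per row from the warped distribution, so
            # next_tokens has shape (batch_size,) after the squeeze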

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
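            # scores are kept in log-space so the running score of each beam can be accumulated by
            # simply adding the per-step log-probabilities below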
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
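            # keeping 2 * num_beams candidates leaves enough alternatives to continue with num_beams
            # hypotheses even if some of the top candidates terminate with eos_token_id at this step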

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size
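            # each flat candidate index equals beam_index * vocab_size + token_id, so the integer
            # division above recovers the beam index and the modulo recovers the vocabulary token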

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
        ... )
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (logits_warper(input_ids, next_token_scores_processed),)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)
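            # the sampled candidates are re-sorted by score so that beam_scorer.process receives them
            # in descending order, just as in plain beam search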
            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSampleEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSampleDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head using **diverse beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

            model_kwargs:
                Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
                model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
            `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
            [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     HammingDiversityLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run diverse beam search using 6 beams
        >>> num_beams = 6
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ...     num_beam_groups=3,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.group_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

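        # `beam_indices` keeps, for every beam group, one tuple per beam hypothesis recording which beam each
        # generated token came from; it is only tracked when scores are returned so that the scores can later
        # be traced back to their originating beams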
        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
        else:
            beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
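        # e.g. num_beams=6, num_beam_groups=3 -> num_sub_beams=2, so each batch row starts as
        # [0, -1e9, 0, -1e9, 0, -1e9] before being flattened below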
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if output_scores:
                processed_score = torch.zeros_like(outputs.logits[:, -1, :])

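            # e.g. num_beams=6, num_beam_groups=3 -> the beam groups [0, 1], [2, 3] and [4, 5] are advanced
            # one after the other within each decoding step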
            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                next_token_scores = nn.functional.log_softmax(
                    next_token_logits, dim=-1
                )  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores_processed

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

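                # the top-k above is taken over the flattened (group_size * vocab_size) scores and keeps
                # 2 * group_size candidates so that hypotheses ending in `eos_token_id` can be set aside by the
                # beam scorer; integer division recovers the beam index inside the group and the modulo below
                # recovers the token id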
                next_indices = torch_int_div(next_tokens, vocab_size)
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                if return_dict_in_generate and output_scores:
                    beam_indices[beam_group_idx] = tuple(
                        beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                    )

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]
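                # record the tokens just chosen by this group so that the diversity penalty (e.g.
                # `HammingDiversityLogitsProcessor`) can discourage later groups from picking them in this step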

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
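                # e.g. num_beams=6, group_size=2, group_start_idx=2, beam_idx=3 -> 6 * 1 + 2 + 1 = 9, i.e. the
                # second beam of the second group of the second batch element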
                reordering_indices[batch_group_indices] = (
                    num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
                )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
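            # the cached key/value states still follow the previous beam layout, so they are reordered with the
            # indices gathered across all groups in this step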
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def constrained_beam_search(
        self,
        input_ids: torch.LongTensor,
        constrained_beam_scorer: ConstrainedBeamSearchScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        r"""
        Generates sequences of token ids for models with a language modeling head using **constrained beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation, while satisfying a list of positive constraints. For more information, the
                documentation of [`ConstrainedBeamSearchScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     ConstrainedBeamSearchScorer,
        ...     PhrasalConstraint,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> constraint_str = "Sie"
        >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
        >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]


        >>> # instantiate beam scorer
        >>> beam_scorer = ConstrainedBeamSearchScorer(
        ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.constrained_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt sind Sie?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(constrained_beam_scorer._beam_hyps)
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
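        # only the first beam of each batch element starts at 0, so the first `topk` call does not select
        # `num_beams` identical candidates from beams that are still all equal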
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)

            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            scores_for_all_vocab = next_token_scores.clone()
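            # the full-vocabulary scores are kept so that the constrained beam scorer can look up the scores of
            # constraint tokens that may not appear among the top-k candidates selected below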

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = (next_tokens / vocab_size).long()
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = constrained_beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                scores_for_all_vocab,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1

            if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = constrained_beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
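
    Example (an illustrative sketch; the batch and vocabulary sizes below are arbitrary):

    ```python
    >>> import torch
    >>> from transformers import top_k_top_p_filtering

    >>> logits = torch.randn(2, 32128)  # (batch size, vocabulary size)
    >>> filtered_logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)
    >>> filtered_logits.shape
    torch.Size([2, 32128])
    ```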
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits