# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from .pytorch_utils import torch_int_div
from .utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape
            `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sampling.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]


class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
          if `constraints!=None` or `force_words_ids!=None`.
403
404
    """

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside {input_name} which is not allowed. "
                f"Make sure to either pass `inputs` or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        # 3. models with `input_ids` can also make use of `inputs_embeds`
        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. Only encoder-decoder models can have non `input_ids` input format
        if not self.config.is_encoder_decoder and input_name != "input_ids":
            raise ValueError(
                f"If {input_name} is passed as model-specific keyword "
                "input then the model has to be an encoder-decoder and not a "
                f"{self.__class__.__name__}."
            )

        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
        if inputs is None:
            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

        return inputs, input_name, model_kwargs
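
    # Illustrative only: a hedged summary of the resolution order implemented above:
    #
    #     # 1. explicit `inputs` argument                  -> used as-is
    #     # 2. the main input name passed as a kwarg       -> popped from `model_kwargs` and used
    #     # 3. `model_kwargs["inputs_embeds"]`             -> used if `forward` accepts it
    #     # 4. otherwise `_prepare_input_ids_for_generation` builds a (1, 1) tensor
    #     #    filled with `bos_token_id`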

    def _can_retrieve_inputs_from_name(
        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        If `inputs` is None and `name` is both a parameter of the model's `forward` function and a passed keyword
        argument, then the inputs can be retrieved from `name`.
        """
        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
            inspect.signature(self.forward).parameters.keys()
        )

        if can_retrieve_inputs and inputs is not None:
            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")

        return can_retrieve_inputs

    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
        """
        return logits

    def _prepare_input_ids_for_generation(
        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
    ) -> torch.LongTensor:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[int],
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)

        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
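
    # Illustrative only: the mask is inferred from `inputs` exactly when it is an
    # `input_ids`-like integer tensor, `pad_token_id` is set and present in the batch,
    # and `pad_token_id` differs from `eos_token_id`; otherwise a mask of ones is used.
    # A hedged example, assuming pad_token_id == 0:
    #
    #     import torch
    #
    #     input_ids = torch.tensor([[5, 6, 0, 0]])
    #     input_ids.ne(0).long()  # -> tensor([[1, 1, 0, 0]])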

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        device: torch.device = None,
    ) -> torch.LongTensor:

        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            if device is None:
                device = self.device
            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs
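
    # Illustrative only: each batch row is repeated `expand_size` times (e.g. `num_beams`
    # or `num_return_sequences`), keeping copies of the same row adjacent. A hedged
    # example of the index pattern for batch_size=2, expand_size=3:
    #
    #     import torch
    #
    #     torch.arange(2).view(-1, 1).repeat(1, 3).view(-1)  # -> tensor([0, 0, 0, 1, 1, 1])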

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs.past_key_values
        elif "mems" in outputs:
            model_kwargs["past"] = outputs.mems
        elif "past_buckets_states" in outputs:
            model_kwargs["past"] = outputs.past_buckets_states
        else:
            model_kwargs["past"] = None

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return model_kwargs
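
    # Illustrative only: a hedged sketch of the call site inside a decoding loop, where
    # the kwargs carried across steps are refreshed from the last forward pass:
    #
    #     outputs = self(**model_inputs, return_dict=True)
    #     model_kwargs = self._update_model_kwargs_for_generation(
    #         outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    #     )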

    def _reorder_cache(self, past, beam_idx):
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )

    def _get_logits_warper(
        self,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        renormalize_logits: Optional[bool] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
        instances used for multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        typical_p = typical_p if typical_p is not None else self.config.typical_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = LogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers
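
    # Illustrative only: a hedged sketch of applying the returned warpers to one step's
    # logits during sampling (shapes assumed to be [batch_size, vocab_size]):
    #
    #     logits_warper = self._get_logits_warper(temperature=0.7, top_k=50, top_p=0.95, num_beams=1)
    #     next_token_scores = logits_warper(input_ids, next_token_logits)
    #     probs = nn.functional.softmax(next_token_scores, dim=-1)
    #     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)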

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
        exponential_decay_length_penalty: Tuple,
        logits_processor: Optional[LogitsProcessorList],
        renormalize_logits: Optional[bool],
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        processors = LogitsProcessorList()

        # init warp parameters
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        encoder_no_repeat_ngram_size = (
            encoder_no_repeat_ngram_size
            if encoder_no_repeat_ngram_size is not None
            else self.config.encoder_no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
        exponential_decay_length_penalty = (
            exponential_decay_length_penalty
            if exponential_decay_length_penalty is not None
            else self.config.exponential_decay_length_penalty
        )
        # instantiate processors list

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if diversity_penalty is not None and diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
                )
            )
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
            if self.config.is_encoder_decoder:
                processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
            else:
                raise ValueError(
                    "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
                )
        if bad_words_ids is not None:
            processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
        if min_length is not None and eos_token_id is not None and min_length > 0:
            processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
        if prefix_allowed_tokens_fn is not None:
            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
        if forced_bos_token_id is not None:
            processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        if remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        if exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
            )
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            processors.append(LogitNormalization())
        return processors
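
    # Illustrative only: the processors run in the order they were appended above. A
    # hedged sketch of one decoding step (shapes assumed to be [batch_size, vocab_size]):
    #
    #     next_token_logits = outputs.logits[:, -1, :]
    #     next_token_scores = logits_processor(input_ids, next_token_logits)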

    def _get_stopping_criteria(
        self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
    ) -> StoppingCriteriaList:
        criteria = StoppingCriteriaList()
        if max_length is not None:
            criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            criteria.append(MaxTimeCriteria(max_time=max_time))
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        return criteria

    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def compute_transition_beam_scores(
        self,
        sequences: torch.Tensor,
        scores: Tuple[torch.Tensor],
        beam_indices: torch.Tensor,
        eos_token_id: int = None,
    ):
        """compute the transition probabilities of sequences given generation
        scores and beam indices"""

        # 1. reshape scores as [batch_size * vocab_size, # generation steps]
        # with # generation steps being seq_len - input_length
        scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

        # 2. cut beam_indices to longest beam length
        beam_indices_mask = beam_indices < 0
        max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
        beam_indices = beam_indices[:, :max_beam_length]
        beam_indices_mask = beam_indices_mask[:, :max_beam_length]

        # 3. Set indices of beams that finished early to 0
        # such indices will be masked correctly afterwards
        beam_indices[beam_indices_mask] = 0

        # 4. multiply beam_indices with vocab size to gather correctly from scores
        beam_sequence_indices = beam_indices * self.config.vocab_size

        # 5. Define which indices contributed to scores
        cut_idx = sequences.shape[-1] - max_beam_length
        indices = sequences[:, cut_idx:] + beam_sequence_indices

        # 6. Compute scores
        transition_scores = scores.gather(0, indices)

        # 7. Mask out transition_scores of beams that stopped early
        transition_scores[beam_indices_mask] = 0

        return transition_scores
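
    # Illustrative only: a hedged sketch of recovering per-step transition scores from a
    # beam search call made with `return_dict_in_generate=True` and `output_scores=True`:
    #
    #     outputs = model.generate(input_ids, num_beams=4, return_dict_in_generate=True, output_scores=True)
    #     transition_scores = model.compute_transition_beam_scores(
    #         outputs.sequences, outputs.scores, outputs.beam_indices
    #     )
    #     # summing transition_scores over steps relates to `outputs.sequences_scores`
    #     # up to the length penalty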

    def _validate_model_class(self):
        """
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        if not hasattr(self, "prepare_inputs_for_generation"):
            generate_compatible_mappings = [
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
                MODEL_FOR_VISION_2_SEQ_MAPPING,
                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
            ]
            generate_compatible_classes = set()
            for model_mapping in generate_compatible_mappings:
                supported_models = model_mapping.get(type(self.config), default=None)
                if supported_models is not None:
                    generate_compatible_classes.add(supported_models.__name__)
            exception_message = (
                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
                "it doesn't have a language model head."
            )
            if generate_compatible_classes:
                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
            raise TypeError(exception_message)

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # Excludes arguments that are handled before calling any model function
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)

        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )
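
    # Illustrative note: a typo in a generate argument, e.g. `model.generate(input_ids, max_new_token=3)`
    # (missing the trailing "s"), ends up in `model_kwargs` and is reported by the ValueError above.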

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        renormalize_logits: Optional[bool] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head. The method supports the following
        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

            - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
              `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
              `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
              `do_sample=False`.
            - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
              `num_beams>1` and `do_sample=True`.
            - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
              `num_beams>1` and `num_beam_groups>1`.
            - *constrained beam-search decoding* by calling
              [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
              `force_words_ids!=None`.

        <Tip warning={true}>

        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
        defined in the model's config (`config.json`) which in turn defaults to the
        [`~modeling_utils.PretrainedConfig`] of the model.

        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            max_length (`int`, *optional*, defaults to `model.config.max_length`):
                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
                the prompt.
            max_new_tokens (`int`, *optional*):
                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
            min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value):
                The minimum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value):
                Whether or not to use sampling; use greedy decoding otherwise.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
            num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value):
                Number of beams for beam search. 1 means no beam search.
            temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
                The amount of probability mass from the original distribution to be considered in typical decoding. If
                set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
            repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value):
                Exponential penalty to the length. It is applied as an exponent to the sequence length, which in turn
                is used to divide the score of the sequence. Since the score is the log likelihood of the sequence
                (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0
                encourages shorter sequences. 0.0 means no penalty.
            no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids(`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`):
                List of token ids that are not allowed to be generated. In order to get the token ids of the words that
                should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
                add_special_tokens=False).input_ids`.
            force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
                List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
                list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
                this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
                where one can allow different forms of each word.
            num_return_sequences(`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value):
                The number of independently computed returned sequences for each element in the batch.
            max_time(`float`, *optional*):
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after the allocated time has passed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
                that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
                as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            use_cache: (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value):
                This value is subtracted from a beam's score if it generates the same token as any beam from another
                group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config an error is thrown. This feature is intended for advanced users.
            renormalize_logits: (`bool`, *optional*, defaults to `False`):
                Whether to renormalize the logits after applying all the logits processors or warpers (including the
                custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
                score logits are normalized but some logit processors or warpers break the normalization.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criteria is passed that is already created with the arguments or a
                model's config an error is thrown. This feature is intended for advanced users.
            constraints (`List[Constraint]`, *optional*):
                Custom constraints that can be added to the generation to ensure that the output will contain the use
                of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
            output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
                crashing. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
                This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
                generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
                where penalty starts and `decay_factor` represents the factor of exponential decay.

            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
                    - [`~generation_utils.SampleDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
                    - [`~generation_utils.SampleEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSampleEncoderDecoderOutput`]

        Examples:

        Greedy Decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # generate up to 30 tokens
        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
        ```

        Multinomial Sampling:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # sample up to 30 tokens
        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
        ```

        Beam-search decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> outputs = model.generate(input_ids, num_beams=5)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
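
        Constrained beam-search decoding (a minimal sketch of the `force_words_ids` argument described above; the
        exact output depends on the model, so the doctest is skipped rather than pinned to an unverified string):

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # force the tokenized word "Sie" to appear somewhere in the generated text
        >>> force_words_ids = tokenizer(["Sie"], add_special_tokens=False).input_ids

        >>> outputs = model.generate(input_ids, force_words_ids=force_words_ids, num_beams=5)  # doctest: +SKIP
        ```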
        ```"""
        # 0. Validate the `.generate()` call
        self._validate_model_class()
        self._validate_model_kwargs(model_kwargs.copy())

        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is None and hasattr(self.config, "decoder"):
            eos_token_id = self.config.decoder.eos_token_id

        if pad_token_id is None and eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # 2. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
        batch_size = inputs_tensor.shape[0]

        # 3. Define other model kwargs
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["use_cache"] = use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, pad_token_id, eos_token_id
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )
        else:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

        # 5. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        if max_length is None and max_new_tokens is None:
            warnings.warn(
                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
                "using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif max_length is None and max_new_tokens is not None:
            max_length = max_new_tokens + input_ids_seq_length
        elif max_length is not None and max_new_tokens is not None:
            raise ValueError(
                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                " limit to the generated output length. Remove one of those arguments. Please refer to the"
                " documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length

        if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
                f"length ({max_length})"
            )
        if input_ids_seq_length >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
                "`max_new_tokens`."
            )

        # 6. determine generation mode
        is_constraint_gen_mode = constraints is not None or force_words_ids is not None
        is_greedy_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_sample_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_beam_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_beam_sample_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # 7. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
            renormalize_logits=renormalize_logits,
        )

        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
        )

        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            if num_beams <= 1:
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

            if do_sample:
                raise ValueError("`do_sample` needs to be false for constrained generation.")

            if num_beam_groups is not None and num_beam_groups > 1:
                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

            final_constraints = []
            if constraints is not None:
                final_constraints = constraints

            if force_words_ids is not None:

                def typeerror():
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                        f"of positive integers, but is {force_words_ids}."
                    )

                if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
                    typeerror()

                for word_ids in force_words_ids:
                    if isinstance(word_ids[0], list):
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)
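
            # Illustrative shapes: force_words_ids=[[42, 7]] forces the phrase [42, 7] via a PhrasalConstraint,
            # while force_words_ids=[[[42, 7], [42, 8]]] admits either form through a DisjunctiveConstraint.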

            # 10. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
            or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "It might be possible to"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> outputs = model.greedy_search(
        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.sample(
        ...     input_ids,
        ...     logits_processor=logits_processor,
        ...     logits_warper=logits_warper,
        ...     stopping_criteria=stopping_criteria,
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
        ```"""

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
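            # processors apply constraints (e.g. a minimum length) to the scores, while
            # warpers (top-k, temperature, ...) reshape the distribution to sample from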
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
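            # draw one token per sequence from the warped distribution; unlike greedy
            # search's argmax, repeated runs can yield different continuations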
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You have not defined any stopping criteria; this will likely loop forever.", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

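        # only the first beam of each batch starts at score 0.0; the rest start at -1e9 so
        # the first top-k over the flattened `num_beams * vocab_size` scores does not pick
        # the same token from `num_beams` identical copies of the prompt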
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

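            # in log space, adding the running `beam_scores` to each token's log-probability
            # yields the cumulative score of every (beam, next token) pair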
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
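            # `2 * num_beams` candidates are kept so that, even if up to `num_beams` of them
            # end in EOS and are handed off to the beam scorer as finished hypotheses,
            # `num_beams` live candidates remain to continue each batch's beams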

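            # `next_tokens` indexes the flattened `num_beams * vocab_size` score matrix;
            # integer division recovers the beam index, the modulo recovers the token id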
            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
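            # the key/value cache rows still follow last step's beam layout; after `process`
            # picks new parent beams, permute the cache to match `beam_idx`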
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

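        # `finalize` pads all hypotheses to a common length and selects the best sequences
        # per batch from the hypotheses accumulated by the beam scorer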
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
        ... )
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (logits_warper(input_ids, next_token_scores_processed),)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            probs = nn.functional.softmax(next_token_scores, dim=-1)

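            # sample `2 * num_beams` candidates per batch, then re-sort them by score so the
            # beam bookkeeping below sees candidates in the same descending order that
            # `torch.topk` would produce in plain beam search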
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSampleEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSampleDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head using **diverse beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

            model_kwargs:
                Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
                model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
            `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
            [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     HammingDiversityLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run diverse beam search using 6 beams
        >>> num_beams = 6
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ...     num_beam_groups=3,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.group_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

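        # `beam_indices` keeps one index trace per beam in each beam group: every step appends the index of the
        # beam each token was selected from, so returned scores can later be traced back to their beams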
        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
        else:
            beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that
        # the beams in the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if output_scores:
                processed_score = torch.zeros_like(outputs.logits[:, -1, :])

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                next_token_scores = nn.functional.log_softmax(
                    next_token_logits, dim=-1
                )  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores_processed

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch_int_div(next_tokens, vocab_size)
                next_tokens = next_tokens % vocab_size

                # stateless
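                # flatten the per-group index traces into the single flat tuple layout the beam scorer works with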
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

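                # each trace grows by one entry per step: the i-th beam of this group descends from
                # previous-step beam `beam_idx[i]`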
                if return_dict_in_generate and output_scores:
                    beam_indices[beam_group_idx] = tuple(
                        beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                    )

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
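                # e.g. num_beams=6, group_size=2, group_start_idx=2: beam_idx 3 maps to batch 1 (3 // 2) at
                # offset 1 (3 % 2), i.e. global beam position 6 * 1 + 2 + 1 = 9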
                reordering_indices[batch_group_indices] = (
                    num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
                )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def constrained_beam_search(
        self,
        input_ids: torch.LongTensor,
        constrained_beam_scorer: ConstrainedBeamSearchScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        r"""
        Generates sequences of token ids for models with a language modeling head using **constrained beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation, while satisfying a list of positive constraints. For more information, the
                documentation of [`ConstrainedBeamSearchScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     ConstrainedBeamSearchScorer,
        ...     PhrasalConstraint,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> constraint_str = "Sie"
        >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
        >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]


        >>> # instantiate beam scorer
        >>> beam_scorer = ConstrainedBeamSearchScorer(
        ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.constrained_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt sind Sie?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You have not defined any stopping_criteria; this will likely loop forever.", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(constrained_beam_scorer._beam_hyps)
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

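        # initialise the score of the first beam of each batch with 0 and the rest with -1e9, so that in the
        # first step only the first beam of each batch is effectively extended and duplicate beams are avoided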
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)

            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

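            # keep a copy of the scores over the full vocabulary: the constrained beam scorer uses them to look
            # up the scores of tokens that advance the not-yet-satisfied constraints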
            scores_for_all_vocab = next_token_scores.clone()

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = constrained_beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                scores_for_all_vocab,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1

            if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = constrained_beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, vocabulary_size)`):
            The logits distribution to filter.
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
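
    Example (a minimal illustrative sketch; the logits values below are arbitrary):

    ```python
    >>> import torch
    >>> from transformers import top_k_top_p_filtering

    >>> logits = torch.tensor([[0.1, 0.2, 0.3, 4.0]])
    >>> filtered = top_k_top_p_filtering(logits, top_k=2)
    >>> int(torch.isinf(filtered).sum())  # the two lowest-scoring tokens were set to `filter_value` (-inf)
    2
    ```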
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

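    # note: the top-p warper runs for any `top_p` in [0, 1]; at the default value of 1.0 it keeps the whole
    # distribution, so it is effectively a no-op unless `top_p < 1.0`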
    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits