# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from .pytorch_utils import torch_int_div
from .utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape
            `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of log probabilities
            of tokens conditioned on the log softmax of previously generated tokens in this beam. Tuple of
            `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each
            tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of log probabilities
            of tokens conditioned on the log softmax of previously generated tokens in this beam. Tuple of
            `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each
            tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sample.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of log probabilities
            of tokens conditioned on the log softmax of previously generated tokens in this beam. Tuple of
            `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each
            tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step, consisting of log probabilities
            of tokens conditioned on the log softmax of previously generated tokens in this beam. Tuple of
            `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each
            tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
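# Illustrative sketch (placeholder names, not a fixed API contract): when `generate()` is
# called with `return_dict_in_generate=True`, it returns one of the output classes aliased
# above, e.g. for a decoder-only model:
#     outputs = model.generate(input_ids, return_dict_in_generate=True, output_scores=True)
#     outputs.sequences  # generated token ids, shape (batch_size, sequence_length)
#     outputs.scores     # tuple with one (batch_size, vocab_size) score tensor per generated token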


class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
          if `constraints!=None` or `force_words_ids!=None`.
    """

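    # Illustrative dispatch sketch based on the class docstring above, assuming the model
    # config keeps the defaults `num_beams=1`, `do_sample=False`, `num_beam_groups=1`
    # (`model` and `input_ids` are placeholders):
    #     model.generate(input_ids)                                  # -> greedy_search
    #     model.generate(input_ids, do_sample=True, top_p=0.9)       # -> sample
    #     model.generate(input_ids, num_beams=4)                     # -> beam_search
    #     model.generate(input_ids, num_beams=4, do_sample=True)     # -> beam_sample
    #     model.generate(input_ids, num_beams=4, num_beam_groups=2)  # -> group_beam_search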
    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside {input_name} which is not allowed. "
                f"Make sure to either pass {inputs} or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        # 3. models with `input_ids` can also make use of `inputs_embeds`
        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. Only encoder-decoder models can have non `input_ids` input format
        if not self.config.is_encoder_decoder and input_name != "input_ids":
            raise ValueError(
                f"If {input_name} is passed as model-specific keyword "
                "input then model has to be an encoder-decoder and not a "
                f"{self.__class__.__name__}."
            )

        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
        if inputs is None:
            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

        return inputs, input_name, model_kwargs

    def _can_retrieve_inputs_from_name(
        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        If `inputs` is None and `name` is both an argument of the model's `forward` function and present in
        `model_kwargs`, then the inputs can be retrieved from `name`.
        """
        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
            inspect.signature(self.forward).parameters.keys()
        )

        if can_retrieve_inputs and inputs is not None:
            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")

        return can_retrieve_inputs

    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
        """
        return logits

    def _prepare_input_ids_for_generation(
        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
    ) -> torch.LongTensor:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[int],
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)

        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
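        # Illustrative sketch: with pad_token_id=0 and 2-D integer `inputs` [[101, 42, 0]],
        # the mask is inputs.ne(0).long() == [[1, 1, 0]]; otherwise a mask of ones is returned.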

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        device: torch.device = None,
    ) -> torch.LongTensor:

        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            if device is None:
                device = self.device
            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs
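        # Illustrative sketch: for input_ids [[1], [2]] and expand_size=3, expanded_return_idx is
        # [0, 0, 0, 1, 1, 1], so rows become [[1], [1], [1], [2], [2], [2]]; attention_mask,
        # token_type_ids and encoder_outputs.last_hidden_state are expanded with the same index.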

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs.past_key_values
        elif "mems" in outputs:
            model_kwargs["past"] = outputs.mems
        elif "past_buckets_states" in outputs:
            model_kwargs["past"] = outputs.past_buckets_states
        else:
            model_kwargs["past"] = None

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return model_kwargs
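        # Illustrative sketch: after one decoding step, `model_kwargs["past"]` holds the cache
        # returned by the model (past_key_values, mems, or past_buckets_states) and, for
        # decoder-only models, an attention_mask of shape (batch, n) grows to (batch, n + 1).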

    def _reorder_cache(self, past, beam_idx):
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )

    def _get_logits_warper(
        self,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        renormalize_logits: Optional[bool] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
        used for multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        typical_p = typical_p if typical_p is not None else self.config.typical_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = LogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers
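        # Illustrative sketch (assuming `config.typical_p` is 1.0 and `renormalize_logits` is unset):
        #     self._get_logits_warper(top_k=50, top_p=0.95, temperature=0.7, num_beams=1)
        # returns [TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper], applied in that
        # order to the next-token scores during sampling.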

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
        exponential_decay_length_penalty: Tuple,
        logits_processor: Optional[LogitsProcessorList],
        renormalize_logits: Optional[bool],
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        processors = LogitsProcessorList()

        # init warp parameters
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        encoder_no_repeat_ngram_size = (
            encoder_no_repeat_ngram_size
            if encoder_no_repeat_ngram_size is not None
            else self.config.encoder_no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
        exponential_decay_length_penalty = (
            exponential_decay_length_penalty
            if exponential_decay_length_penalty is not None
            else self.config.exponential_decay_length_penalty
        )
        # instantiate processors list

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if diversity_penalty is not None and diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
                )
            )
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
            if self.config.is_encoder_decoder:
                processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
            else:
                raise ValueError(
                    "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
                )
        if bad_words_ids is not None:
            processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
        if min_length is not None and eos_token_id is not None and min_length > 0:
            processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
        if prefix_allowed_tokens_fn is not None:
            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
        if forced_bos_token_id is not None:
            processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        if remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        if exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
            )
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            processors.append(LogitNormalization())
        return processors

    def _get_stopping_criteria(
        self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
    ) -> StoppingCriteriaList:
        criteria = StoppingCriteriaList()
        if max_length is not None:
            criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            criteria.append(MaxTimeCriteria(max_time=max_time))
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        return criteria

    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def compute_transition_beam_scores(
        self,
        sequences: torch.Tensor,
        scores: Tuple[torch.Tensor],
        beam_indices: torch.Tensor,
        eos_token_id: int = None,
    ):
        """compute the transition probabilities of sequences given generation
        scores and beam indices"""

        # 1. reshape scores as [batch_size * num_beams * vocab_size, # generation steps]
        # with # generation steps being seq_len - input_length
        scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

        # 2. cut beam_indices to longest beam length
        beam_indices_mask = beam_indices < 0
        max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
        beam_indices = beam_indices[:, :max_beam_length]
        beam_indices_mask = beam_indices_mask[:, :max_beam_length]

        # 3. Set indices of beams that finished early to 0
        # such indices will be masked correctly afterwards
        beam_indices[beam_indices_mask] = 0

        # 4. multiply beam_indices with vocab size to gather correctly from scores
        beam_sequence_indices = beam_indices * self.config.vocab_size

        # 5. Define which indices contributed to scores
        cut_idx = sequences.shape[-1] - max_beam_length
        indices = sequences[:, cut_idx:] + beam_sequence_indices

        # 6. Compute scores
        transition_scores = scores.gather(0, indices)

        # 7. Mask out transition_scores of beams that stopped early
        transition_scores[beam_indices_mask] = 0

        return transition_scores
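        # Illustrative usage sketch (placeholder names): for a beam search call such as
        #     out = model.generate(input_ids, num_beams=4, return_dict_in_generate=True, output_scores=True)
        # the per-step scores of the returned beams can be recovered with
        #     model.compute_transition_beam_scores(out.sequences, out.scores, out.beam_indices)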

    def _validate_model_class(self):
        """
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        if not hasattr(self, "prepare_inputs_for_generation"):
            generate_compatible_mappings = [
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
                MODEL_FOR_VISION_2_SEQ_MAPPING,
                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
            ]
            generate_compatible_classes = set()
            for model_mapping in generate_compatible_mappings:
                supported_models = model_mapping.get(type(self.config), default=None)
                if supported_models is not None:
                    generate_compatible_classes.add(supported_models.__name__)
            exception_message = (
                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
                "it doesn't have a language model head."
            )
            if generate_compatible_classes:
                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
            raise TypeError(exception_message)

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # Excludes arguments that are handled before calling any model function
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)

        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        renormalize_logits: Optional[bool] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head. The method supports the following
        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

            - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
              `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
              `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
              `do_sample=False`.
            - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
              `num_beams>1` and `do_sample=True`.
            - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
              `num_beams>1` and `num_beam_groups>1`.
            - *constrained beam-search decoding* by calling
              [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
              `force_words_ids!=None`.

        <Tip warning={true}>

        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
        defined in the model's config (`config.json`) which in turn defaults to the
        [`~modeling_utils.PretrainedConfig`] of the model.

        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            max_length (`int`, *optional*, defaults to `model.config.max_length`):
                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
                the prompt.
            max_new_tokens (`int`, *optional*):
                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
            min_length (`int`, *optional*, defaults to `model.config.min_length` or 10 if the config does not set any value):
                The minimum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `model.config.do_sample` or `False` if the config does not set any value):
                Whether or not to use sampling; use greedy decoding otherwise.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
            num_beams (`int`, *optional*, defaults to `model.config.num_beams` or 1 if the config does not set any value):
                Number of beams for beam search. 1 means no beam search.
            temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
                The amount of probability mass from the original distribution to be considered in typical decoding. If
                set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
            repetition_penalty (`float`, *optional*, defaults to `model.config.repetition_penalty` or 1.0 if the config does not set any value):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            pad_token_id (`int`, *optional*, defaults to `model.config.pad_token_id`):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*, defaults to `model.config.bos_token_id`):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*, defaults to `model.config.eos_token_id`):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to `model.config.length_penalty` or 1.0 if the config does not set any value):
                Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
                while `length_penalty` < 0.0 encourages shorter sequences.
            no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.no_repeat_ngram_size` or 0 if the config does not set any value):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to `model.config.encoder_no_repeat_ngram_size` or 0 if the config does not set any value):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids(`List[List[int]]`, *optional*, defaults to `model.config.bad_words_ids`):
                List of token ids that are not allowed to be generated. In order to get the token ids of the words that
                should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
                add_special_tokens=False).input_ids`.
            force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
                List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
                list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
                this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
                where one can allow different forms of each word.
            num_return_sequences(`int`, *optional*, defaults to `model.config.num_return_sequences` or 1 if the config does not set any value):
                The number of independently computed returned sequences for each element in the batch.
            max_time(`float`, *optional*):
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after allocated time has been passed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
                that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
                as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            use_cache: (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            num_beam_groups (`int`, *optional*, defaults to `model.config.num_beam_groups` or 1 if the config does not set any value):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to `model.config.diversity_penalty` or 0.0 if the config does not set any value):
                This value is subtracted from a beam's score if it generates a token same as any beam from other group
                at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                 Custom logits processors that complement the default logits processors built from arguments and a
                 model's config. If a logit processor is passed that is already created with the arguments or a model's
                 config an error is thrown. This feature is intended for advanced users.
            renormalize_logits: (`bool`, *optional*, defaults to `False`):
                Whether to renormalize the logits after applying all the logits processors or warpers (including the
                custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
                score logits are normalized but some logit processors or warpers break the normalization.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 model's config. If a stopping criteria is passed that is already created with the arguments or a
                 model's config an error is thrown. This feature is intended for advanced users.
            constraints (`List[Constraint]`, *optional*):
                 Custom constraints that can be added to the generation to ensure that the output will contain the use
                 of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
            output_attentions (`bool`, *optional*, defaults to `model.config.output_attentions` or `False` if the config does not set any value):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `model.config.output_hidden_states` or `False` if the config does not set any value):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `model.config.output_scores` or `False` if the config does not set any value):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `model.config.return_dict_in_generate` or `False` if the config does not set any value):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
                crashing. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
                This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
                generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
                where penalty starts and `decay_factor` represents the factor of exponential decay

            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
                    - [`~generation_utils.SampleDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
                    - [`~generation_utils.SampleEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSampleEncoderDecoderOutput`]

        Examples:

        Greedy Decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # generate up to 30 tokens
        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
        ```

        Multinomial Sampling:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # sample up to 30 tokens
        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
        ```

        Beam-search decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> outputs = model.generate(input_ids, num_beams=5)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
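        ```

        Constrained beam-search decoding (a minimal sketch reusing the translation model above; the forced word is an
        arbitrary choice for illustration and the exact output depends on the model, so it is not shown):

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> # force the tokens of a chosen word to appear somewhere in the generated translation
        >>> force_words_ids = tokenizer(["Paris"], add_special_tokens=False).input_ids
        >>> outputs = model.generate(input_ids, num_beams=5, force_words_ids=force_words_ids)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)  # doctest: +SKIP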
        ```"""
        # 0. Validate the `.generate()` call
        self._validate_model_class()
        self._validate_model_kwargs(model_kwargs.copy())

        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is None and hasattr(self.config, "decoder"):
            eos_token_id = self.config.decoder.eos_token_id

        if pad_token_id is None and eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # 2. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
        batch_size = inputs_tensor.shape[0]

        # 3. Define other model kwargs
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["use_cache"] = use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, pad_token_id, eos_token_id
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )
        else:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

        # 5. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        if max_length is None and max_new_tokens is None:
            warnings.warn(
                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
                "using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif max_length is None and max_new_tokens is not None:
            max_length = max_new_tokens + input_ids_seq_length
        elif max_length is not None and max_new_tokens is not None:
            raise ValueError(
                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                " limit to the generated output length. Remove one of those arguments. Please refer to the"
                " documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length

        if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
                f"length ({max_length})"
            )
        if input_ids_seq_length >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
                "`max_new_tokens`."
            )

        # 6. determine generation mode
        is_constraint_gen_mode = constraints is not None or force_words_ids is not None
        is_greedy_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_sample_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_beam_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_beam_sample_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # 7. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
            renormalize_logits=renormalize_logits,
        )

        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
        )

        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            if num_beams <= 1:
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

            if do_sample:
                raise ValueError("`do_sample` needs to be false for constrained generation.")

            if num_beam_groups is not None and num_beam_groups > 1:
                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

            final_constraints = []
            if constraints is not None:
                final_constraints = constraints

            if force_words_ids is not None:

                def typeerror():
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                        f"of positive integers, but is {force_words_ids}."
                    )

                if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
                    typeerror()

                for word_ids in force_words_ids:
                    if isinstance(word_ids[0], list):
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 10. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=num_beams,
                device=inputs_tensor.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.

            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
            or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "It might be possible to"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> outputs = model.greedy_search(
        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> # instantiate logits processors
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.sample(
        ...     input_ids,
        ...     logits_processor=logits_processor,
        ...     logits_warper=logits_warper,
        ...     stopping_criteria=stopping_criteria,
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
        ```"""

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
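            # turn the warped scores into a probability distribution and draw one token id per sequence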
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
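        # only the first beam of every batch starts at score 0 (the rest at -1e9), so the first topk step
        # does not simply select `num_beams` copies of the same hypothesis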

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
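            # `beam_scores` are running sums of log-probabilities, so each candidate now carries its full sequence score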

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
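            # keep 2 * num_beams candidates so that `num_beams` non-EOS continuations survive even if some of the
            # best candidates end their hypothesis; the flat indices are split into a beam index and a token id below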

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
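            # the cached key/value states have to follow the beam reordering so every beam continues from its parent's cache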
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
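                # extend each surviving beam's ancestry with the index of the beam it was selected from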
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )
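        # `sequence_outputs` holds the finalized hypotheses: the padded `sequences`, their `sequence_scores`,
        # and (when tracked) the per-step `beam_indices`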

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # let's run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
        ... )
        >>> # instantiate logits processors
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))
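        # unlike `beam_search`, every beam starts at score 0 here: multinomial sampling already breaks the tie
        # between identical beams that the -1e9 offset prevents in pure beam search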

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
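            # warpers (e.g. top-k / temperature) are applied after the running beam scores have been added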
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (logits_warper(input_ids, next_token_scores_processed),)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            probs = nn.functional.softmax(next_token_scores, dim=-1)
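            # draw 2 * num_beams candidate tokens per batch entry so enough non-EOS continuations survive,
            # then gather their scores and sort the candidates by score in descending order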

            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSampleEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSampleDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head using **diverse beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

            model_kwargs:
                Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
                model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
            `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
            [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     HammingDiversityLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run diverse beam search using 6 beams
        >>> num_beams = 6
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ...     num_beam_groups=3,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.group_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
        else:
            beam_indices = None
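        # When populated, `beam_indices` holds one tuple per beam group; each inner tuple accumulates, for every
        # hypothesis in that group, the beam index each generated token was chosen from, so scores returned with
        # `return_dict_in_generate=True` can later be traced back through the beam ancestry.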

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))
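        # e.g. with `num_beams=6` and `num_beam_groups=3` (`num_sub_beams=2`), each batch row starts as
        # [0, -1e9, 0, -1e9, 0, -1e9]: only the leading beam of every group contributes at the first step,
        # so the groups do not all copy the single best continuation.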

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if output_scores:
                processed_score = torch.zeros_like(outputs.logits[:, -1, :])

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx
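                # e.g. with `num_beams=6` and `num_beam_groups=3`, group 1 covers beams 2 and 3 of every sentence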

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                next_token_scores = nn.functional.log_softmax(
                    next_token_logits, dim=-1
                )  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores_processed

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )
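                # keep 2 * group_size candidates so the scorer can discard hypotheses that end in `eos_token_id`
                # and still have `group_size` open beams to continue with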

                next_indices = torch_int_div(next_tokens, vocab_size)
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                if return_dict_in_generate and output_scores:
                    beam_indices[beam_group_idx] = tuple(
                        beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                    )

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
                )
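                # e.g. with `num_beams=6`, `group_size=2`, `group_start_idx=2`: local `beam_idx=3` belongs to
                # sentence 1 (3 // 2) at offset 1 (3 % 2), i.e. global beam 6 * 1 + 2 + 1 = 9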

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
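            # the per-group loop above already wrote the reordered prefixes back into `input_ids`, so here only the
            # freshly selected token of every beam is appended in one shot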

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def constrained_beam_search(
        self,
        input_ids: torch.LongTensor,
        constrained_beam_scorer: ConstrainedBeamSearchScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        r"""
        Generates sequences of token ids for models with a language modeling head using **constrained beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation, while satisfying a list of positive constraints. For more information, the
                documentation of [`ConstrainedBeamSearchScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     ConstrainedBeamSearchScorer,
        ...     PhrasalConstraint,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> constraint_str = "Sie"
        >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
        >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]


        >>> # instantiate beam scorer
        >>> beam_scorer = ConstrainedBeamSearchScorer(
        ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.constrained_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt sind Sie?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You have not defined any stopping_criteria; this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(constrained_beam_scorer._beam_hyps)
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
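        # only the first beam of each sentence starts at score 0; the remaining beams start at -1e9 so that the
        # first step effectively expands a single hypothesis (all beams hold identical copies of the prompt)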

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)

            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            scores_for_all_vocab = next_token_scores.clone()
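            # keep the scores over the full vocabulary: the constrained scorer needs them (not just the top-k
            # candidates selected below) to evaluate tokens that advance a not-yet-satisfied constraint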

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = (next_tokens / vocab_size).long()
            next_tokens = next_tokens % vocab_size
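            # `torch.topk` ran over the flattened (num_beams * vocab_size) axis, so the division recovers the beam
            # each candidate came from and the modulo recovers its token id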

            # stateless
            beam_outputs = constrained_beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                scores_for_all_vocab,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1

            if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = constrained_beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
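    Example (a minimal sampling sketch; the logits values below are purely illustrative):

    ```python
    >>> import torch
    >>> from transformers import top_k_top_p_filtering

    >>> logits = torch.randn(1, 100)  # (batch size, vocabulary size)
    >>> filtered_logits = top_k_top_p_filtering(logits, top_k=10, top_p=0.9)
    >>> probs = torch.softmax(filtered_logits, dim=-1)  # filtered positions end up with probability 0
    >>> next_token = torch.multinomial(probs, num_samples=1)
    ```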
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits