"vscode:/vscode.git/clone" did not exist on "4a51b1dd9b8ea1c77b50f0cd91fe5ae8d57ad369"
generation_utils.py 174 KB
Newer Older
1
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .pytorch_utils import torch_int_div
from .utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each
            tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each
            tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_return_sequences, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape
            `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the decoder_hidden_states
    attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sample.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
    decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
            tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]


class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
          if `constraints!=None` or `force_words_ids!=None`.
    """

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside {input_name} which is not allowed. "
                f"Make sure to either pass {inputs} or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        # 3. models with `input_ids` can also make use of `inputs_embeds`
        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. Only encoder-decoder models can have non `input_ids` input format
        if not self.config.is_encoder_decoder and input_name != "input_ids":
            raise ValueError(
                f"If {input_name} is passed as a model-specific keyword "
                "input, then the model has to be an encoder-decoder and not a "
                f"{self.__class__.__name__}."
            )

        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
        if inputs is None:
            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

        return inputs, input_name, model_kwargs
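    # The triple returned above: the input tensor to feed the model, the model-specific name it must be
    # passed under (e.g. "input_ids" or "inputs_embeds"), and the remaining, cleaned-up model kwargs.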

    def _can_retrieve_inputs_from_name(
        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
        from name
        """
        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
            inspect.signature(self.forward).parameters.keys()
        )

        if can_retrieve_inputs and inputs is not None:
            raise ValueError(f"Can only pass one of {name} and {self.main_input_name}")

        return can_retrieve_inputs

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
        """
        return {"input_ids": input_ids}

    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
        """
        return logits

    def _prepare_input_ids_for_generation(
        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
    ) -> torch.LongTensor:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: int,
        eos_token_id: int,
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
            (eos_token_id is not None) and (pad_token_id != eos_token_id)
        )
        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
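    # Worked example of the rule above (illustrative values only): with pad_token_id=0, eos_token_id=2
    # and inputs == [[5, 6, 0]], the input is 2-D, contains the pad token and pad != eos, so the mask is
    # inputs.ne(0).long() == [[1, 1, 0]]; otherwise an all-ones mask of shape inputs.shape[:2] is returned.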

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.LongTensor:

        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
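        # Resolution order, derived from the branches below: the `decoder_start_token_id` argument,
        # then `config.decoder_start_token_id`, then `config.decoder.decoder_start_token_id`,
        # then the `bos_token_id` argument, then `config.bos_token_id`, then `config.decoder.bos_token_id`.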
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs
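    # Worked example of the expansion above (illustrative, not executed): with a batch of 2 prompts and
    # expand_size=3 (e.g. num_return_sequences=3), expanded_return_idx == tensor([0, 0, 0, 1, 1, 1]), so
    # every row of `input_ids`, `attention_mask`, `token_type_ids` and `encoder_outputs.last_hidden_state`
    # is repeated 3 times along the batch dimension.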

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs.past_key_values
        elif "mems" in outputs:
            model_kwargs["past"] = outputs.mems
        elif "past_buckets_states" in outputs:
            model_kwargs["past"] = outputs.past_buckets_states
        else:
            model_kwargs["past"] = None

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # update attention mask
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return model_kwargs
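    # Note on the update above: for decoder-only models the attention mask grows by one column of ones per
    # step, matching the token appended to `input_ids`; for encoder-decoder models only the cache (`past`)
    # and, if present, `token_type_ids` are extended, while the encoder attention mask is left untouched.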

    def _reorder_cache(self, past, beam_idx):
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
        )

    def _get_logits_warper(
        self,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        renormalize_logits: Optional[bool] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
        used for multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        typical_p = typical_p if typical_p is not None else self.config.typical_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = LogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            warpers.append(LogitNormalization())
        return warpers
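    # How the returned list is used (a sketch; `next_token_logits` is a placeholder name, not defined here):
    # a `LogitsProcessorList` is callable, so sampling loops apply it as
    # `next_token_scores = warpers(input_ids, next_token_logits)` before sampling with `torch.multinomial`.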

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
        exponential_decay_length_penalty: Tuple,
        logits_processor: Optional[LogitsProcessorList],
        renormalize_logits: Optional[bool],
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        processors = LogitsProcessorList()

        # init warp parameters
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        encoder_no_repeat_ngram_size = (
            encoder_no_repeat_ngram_size
            if encoder_no_repeat_ngram_size is not None
            else self.config.encoder_no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
        exponential_decay_length_penalty = (
            exponential_decay_length_penalty
            if exponential_decay_length_penalty is not None
            else self.config.exponential_decay_length_penalty
        )
        # instantiate processors list

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if diversity_penalty is not None and diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
                )
            )
        if repetition_penalty is not None and repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
            if self.config.is_encoder_decoder:
                processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
            else:
                raise ValueError(
                    "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
                )
        if bad_words_ids is not None:
            processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
        if min_length is not None and eos_token_id is not None and min_length > 0:
            processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
        if prefix_allowed_tokens_fn is not None:
            processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
        if forced_bos_token_id is not None:
            processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        if remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        if exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
            )
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        # `LogitNormalization` should always be the last logit processor, when present
        if renormalize_logits is True:
            processors.append(LogitNormalization())
        return processors

    def _get_stopping_criteria(
        self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
    ) -> StoppingCriteriaList:
        criteria = StoppingCriteriaList()
        if max_length is not None:
            criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            criteria.append(MaxTimeCriteria(max_time=max_time))
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        return criteria

    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, "
                        f"but it has already been created with the values {default}. {default} has been created by passing the "
                        "corresponding arguments to generate or by the model's config default values. "
                        f"If you just want to change the default values of {object_type} consider passing them as arguments "
                        f"to `generate` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def compute_transition_beam_scores(
        self,
        sequences: torch.Tensor,
        scores: Tuple[torch.Tensor],
        beam_indices: torch.Tensor,
        eos_token_id: int = None,
    ):
        """compute the transition probabilities of sequences given generation
        scores and beam indices"""

        # stack and reshape scores to [batch_size * num_beams * vocab_size, # generation steps],
        # where # generation steps == seq_len - input_length
        scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

        # start of generated tokens
        cut_idx = sequences.shape[-1] - scores.shape[-1]
        # adjust for beam indices
        beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
        # compute real indices
        indices = sequences[:, cut_idx:] + beam_sequence_indices
        # gather scores and run
        transition_scores = scores.gather(0, indices)
        # if an EOS token occurs before the end of the sequence, zero out the transition scores
        # of all positions after its first occurrence
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is not None:
            is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
            # make sure first eos token still contributes to transition probs
            is_eos_token_id[:, -1] = False
            is_eos_token_id = is_eos_token_id.roll(1, -1)
            # all indices after the first eos should be masked
            zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
            # zero out padded probs
            transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)

        return transition_scores
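    # Index arithmetic used above (illustrative): each per-step score tensor is flattened to
    # batch_size * num_beams * vocab_size entries, so the score of token `t` chosen on the (flattened)
    # beam row `b` at step `s` sits at flat index b * vocab_size + t in column `s` of the stacked scores.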

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        renormalize_logits: Optional[bool] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        constraints: Optional[List[Constraint]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
877
878
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
879
        remove_invalid_values: Optional[bool] = None,
880
        synced_gpus: Optional[bool] = False,
881
        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
882
883
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
Sylvain Gugger's avatar
Sylvain Gugger committed
884
        r"""
        Generates sequences of token ids for models with a language modeling head. The method supports the following
        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

            - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
              `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
              `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
              `do_sample=False`.
            - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
              `num_beams>1` and `do_sample=True`.
            - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
              `num_beams>1` and `num_beam_groups>1`.
            - *constrained beam-search decoding* by calling
              [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
              `force_words_ids!=None`.

        <Tip warning={true}>

        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
        defined in the model's config (`config.json`) which in turn defaults to the
        [`~modeling_utils.PretrainedConfig`] of the model.

        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            max_length (`int`, *optional*, defaults to `model.config.max_length`):
                The maximum length of the sequence to be generated.
            max_new_tokens (`int`, *optional*, defaults to None):
                The maximum number of tokens to generate, ignoring the number of tokens in the prompt. Use either
                `max_new_tokens` or `max_length` but not both, they serve the same purpose.
            min_length (`int`, *optional*, defaults to 10):
                The minimum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `False`):
                Whether or not to use sampling; use greedy decoding otherwise.
            early_stopping (`bool`, *optional*, defaults to `False`):
                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            repetition_penalty (`float`, *optional*, defaults to 1.0):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            length_penalty (`float`, *optional*, defaults to 1.0):
                Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
                model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
                sequences.
            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size can only occur once.
            encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
                If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
                `decoder_input_ids`.
            bad_words_ids(`List[List[int]]`, *optional*):
                List of token ids that are not allowed to be generated. In order to get the token ids of the words that
                should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
                add_special_tokens=False).input_ids`.
            force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
                List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
                list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
                this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
                where one can allow different forms of each word.
            num_return_sequences(`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch.
            max_time(`float`, *optional*, defaults to None):
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after the allocated time has passed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
                that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
                as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            use_cache: (`bool`, *optional*, defaults to `True`):
                Whether or not the model should use the past key/values attentions (if applicable to the model) to
                speed up decoding.
            num_beam_groups (`int`, *optional*, defaults to 1):
                Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            diversity_penalty (`float`, *optional*, defaults to 0.0):
                This value is subtracted from a beam's score if it generates the same token as any beam from another
                group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
                enabled.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                 Custom logits processors that complement the default logits processors built from arguments and a
                 model's config. If a logit processor is passed that is already created with the arguments or a model's
                 config an error is thrown. This feature is intended for advanced users.
            renormalize_logits: (`bool`, *optional*, defaults to `False`):
                Whether to renormalize the logits after applying all the logits processors or warpers (including the
                custom ones). It's highly recommended to set this flag to `True` as the search algorithms assume the
                score logits are normalized but some logit processors or warpers break the normalization.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 model's config. If a stopping criteria is passed that is already created with the arguments or a
                 model's config an error is thrown. This feature is intended for advanced users.
            constraints (`List[Constraint]`, *optional*):
                 Custom constraints that can be added to the generation to ensure that the output will contain the use
                 of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            forced_bos_token_id (`int`, *optional*):
                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
                the target language token.
            forced_eos_token_id (`int`, *optional*):
                The id of the token to force as the last generated token when `max_length` is reached.
            remove_invalid_values (`bool`, *optional*):
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
                crashing. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
                This tuple adds an exponentially increasing length penalty after a certain number of tokens have been
                generated. The tuple shall consist of `(start_index, decay_factor)`, where `start_index` indicates
                where the penalty starts and `decay_factor` represents the factor of exponential decay.

            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
                    - [`~generation_utils.SampleDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
                    - [`~generation_utils.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
                    - [`~generation_utils.SampleEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
                    - [`~generation_utils.BeamSampleEncoderDecoderOutput`]

        Examples:

        Greedy Decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # generate up to 30 tokens
        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
        ```

        Multinomial Sampling:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> prompt = "Today I believe we can finally"
        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

        >>> # sample up to 30 tokens
        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
        ```

        Beam-search decoding:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> outputs = model.generate(input_ids)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
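        ```

        Constrained beam-search decoding with `force_words_ids` (a minimal sketch that reuses the checkpoint from the
        example above; the decoded output is not shown since it depends on the checkpoint):

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> sentence = "Paris is one of the densest populated areas in Europe."
        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids

        >>> # force the (tokenized) word "Europa" to appear in the translation
        >>> force_words_ids = tokenizer(["Europa"], add_special_tokens=False).input_ids

        >>> outputs = model.generate(input_ids, force_words_ids=force_words_ids, num_beams=5)
        >>> translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)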
        ```"""
        # 1. Set generation parameters if not already defined
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if eos_token_id is None and hasattr(self.config, "decoder"):
            eos_token_id = self.config.decoder.eos_token_id

        if pad_token_id is None and eos_token_id is not None:
            # special case if pad_token_id is not defined
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # 2. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
        batch_size = inputs_tensor.shape[0]

        # 3. Define other model kwargs
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["use_cache"] = use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
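        # only build a default attention mask below if none was passed, the model's `forward` accepts one,
        # and no precomputed `encoder_outputs` were given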

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, pad_token_id, eos_token_id
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 4. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
            )
        else:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

        input_ids_seq_length = input_ids.shape[-1]

        # 5. Prepare `max_length` depending on other stopping criteria
        # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
        if max_length is None and max_new_tokens is not None:
            max_length = max_new_tokens + input_ids_seq_length
        elif max_length is not None and max_new_tokens is not None:
            # Both are set, this is odd, raise a warning
            warnings.warn(
                "Both `max_length` and `max_new_tokens` have been set "
                f"but they serve the same purpose. `max_length` {max_length} "
                f"will take priority over `max_new_tokens` {max_new_tokens}.",
                UserWarning,
            )
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length

        if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
                f"length ({max_length})"
            )
        if input_ids_seq_length >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. "
                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
            )

        # 6. determine generation mode
        is_constraint_gen_mode = constraints is not None or force_words_ids is not None
        is_greedy_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_sample_gen_mode = (
            (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_beam_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
        )
        is_beam_sample_gen_mode = (
            (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
        )
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
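        # note: when `constraints` or `force_words_ids` are passed, the greedy/sample/beam flags above are all
        # `False` and constrained beam search is used below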

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # 7. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
            renormalize_logits=renormalize_logits,
        )

        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
        )

        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                num_beams=num_beams,
                renormalize_logits=renormalize_logits,
            )

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if num_beams % num_beam_groups != 0:
                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif is_constraint_gen_mode:
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            if num_beams <= 1:
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

            if do_sample:
                raise ValueError("`do_sample` needs to be false for constrained generation.")

            if num_beam_groups is not None and num_beam_groups > 1:
                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

            final_constraints = []
            if constraints is not None:
                final_constraints = constraints

            if force_words_ids is not None:
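                # `force_words_ids` may be a list of words (each a list of token ids) or a list of lists of
                # alternative tokenizations (disjunctive constraints); validate and convert it to `Constraint` objects below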

                def typeerror():
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                        f"of positive integers, but is {force_words_ids}."
                    )

                if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
                    typeerror()

                for word_ids in force_words_ids:
                    if isinstance(word_ids[0], list):
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 10. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # 12. run beam search
            return self.constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.

            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
            or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "It might be possible to"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> outputs = model.greedy_search(
        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
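        # (`unfinished_sequences` holds 1 for sequences that are still being generated and 0 for finished ones)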
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForCausalLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     StoppingCriteriaList,
        ...     MaxLengthCriteria,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> input_prompt = "Today is a beautiful day, and"
        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
        >>> outputs = model.sample(
        ...     input_ids,
        ...     logits_processor=logits_processor,
        ...     logits_warper=logits_warper,
        ...     stopping_criteria=stopping_criteria,
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
        ```"""

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
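        # note: a value of 1 means the sequence is still generating, 0 means it already emitted `eos_token_id`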
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
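            # logits_processor enforces constraints (e.g. a minimum length) while logits_warper reshapes the
            # distribution (e.g. temperature, top-k/top-p) right before sampling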
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
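            # draw one token per sequence from the warped distribution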
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You have not defined any stopping_criteria; this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
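        # only the first beam of each batch item keeps score 0; the remaining beams get -1e9 so the first top-k
        # step does not select `num_beams` identical continuations of the same prompt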
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
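            # flattening the beam and vocab dimensions lets a single top-k pick candidates across all beams of a
            # batch item; 2 * num_beams candidates are kept so enough non-EOS continuations remain available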
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch_int_div(next_tokens, vocab_size)
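            # the quotient above is the beam each candidate came from; the remainder below is its token id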
            next_tokens = next_tokens % vocab_size

            # stateless
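            # `beam_scorer.process` selects, per batch item, the `num_beams` best continuations among the
            # 2 * num_beams candidates and keeps track of finished hypotheses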
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
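            # the cached key/values still follow the previous beam layout, so reorder them to match the beams
            # selected above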
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            else:
                num_return_sequences = beam_scorer.num_beam_hyps_to_keep
                # return only as many indices as sequences
                beam_indices = tuple(
                    (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
                )
                beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSampleOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     TopKLogitsWarper,
        ...     TemperatureLogitsWarper,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

        >>> # lets run beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
        ... )
        >>> # instantiate logits warpers
        >>> logits_warper = LogitsProcessorList(
        ...     [
        ...         TopKLogitsWarper(50),
        ...         TemperatureLogitsWarper(0.7),
        ...     ]
        ... )

        >>> outputs = model.beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
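                    # store the processed and warped per-step scores, without the running beam scores added in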
                    scores += (logits_warper(input_ids, next_token_scores_processed),)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            probs = nn.functional.softmax(next_token_scores, dim=-1)
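            # draw 2 * num_beams candidates per batch item and sort them by score below, so that
            # `beam_scorer.process` receives them in descending order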

            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            else:
                num_return_sequences = beam_scorer.num_beam_hyps_to_keep
                # return only as many indices as sequences
                beam_indices = tuple(
                    (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
                )
                beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSampleEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSampleDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head using **diverse beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:

            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
                model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
            `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
            [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.

        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     HammingDiversityLogitsProcessor,
        ...     BeamSearchScorer,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # lets run diverse beam search using 6 beams
        >>> num_beams = 6
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     max_length=model.config.max_length,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ...     num_beam_groups=3,
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.group_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt bist du?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
        else:
            beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't all produce the same tokens at every step.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if output_scores:
                processed_score = torch.zeros_like(outputs.logits[:, -1, :])

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx
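                # e.g. with the `num_beams=6, num_beam_groups=3` setup from the example above, `num_sub_beams` is 2 and
                # group 1 covers beam indices 2..3; each group is advanced separately so that the diversity penalty can
                # compare its candidates against the tokens already picked by earlier groups in this step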

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                next_token_scores = nn.functional.log_softmax(
                    next_token_logits, dim=-1
                )  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
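                # add the running log-probability of each beam so that candidates are ranked by the score of the whole
                # hypothesis rather than by the last token alone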
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores_processed

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )
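                # `next_tokens` indexes the flattened (group_size * vocab_size) scores; the beam index within the group
                # and the vocabulary id are recovered below. 2 * group_size candidates are kept so the group can still
                # be refilled if some of the best candidates turn out to be EOS tokens.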

                next_indices = torch_int_div(next_tokens, vocab_size)
                next_tokens = next_tokens % vocab_size

                # stateless
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                if return_dict_in_generate and output_scores:
                    beam_indices[beam_group_idx] = tuple(
                        beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                    )
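                    # each entry of `beam_indices[beam_group_idx]` is the chain of beam indices that produced the
                    # corresponding hypothesis, extended here by the beam each new token was selected from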

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
                )
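                # `reordering_indices` maps the group-local `beam_idx` back to the global (batch_size * num_beams)
                # layout so that the cache can be reordered once per step for all groups at the end of the loop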

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
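            # `current_tokens` was filled group by group above, so a single concatenation appends each beam's newly
            # selected token to the correct row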

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            else:
                beam_indices = sum(beam_indices, ())
                num_return_sequences = beam_scorer.num_beam_hyps_to_keep
                # return only as many indices as sequences
                beam_indices = tuple(
                    (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
                )
                beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=beam_indices,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def constrained_beam_search(
        self,
        input_ids: torch.LongTensor,
        constrained_beam_scorer: ConstrainedBeamSearchScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        r"""
        Generates sequences of token ids for models with a language modeling head using **constrained beam search
        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation, while satisfying a list of positive constraints. For more information, the
                documentation of [`ConstrainedBeamSearchScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.


        Examples:

        ```python
        >>> from transformers import (
        ...     AutoTokenizer,
        ...     AutoModelForSeq2SeqLM,
        ...     LogitsProcessorList,
        ...     MinLengthLogitsProcessor,
        ...     ConstrainedBeamSearchScorer,
        ...     PhrasalConstraint,
        ... )
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_str = "translate English to German: How old are you?"
        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


        >>> # let's run constrained beam search using 3 beams
        >>> num_beams = 3
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> constraint_str = "Sie"
        >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
        >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]


        >>> # instantiate beam scorer
        >>> beam_scorer = ConstrainedBeamSearchScorer(
        ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
        ... )

        >>> # instantiate logits processors
        >>> logits_processor = LogitsProcessorList(
        ...     [
        ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
        ...     ]
        ... )

        >>> outputs = model.constrained_beam_search(
        ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
        ... )

        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ['Wie alt sind Sie?']
        ```"""
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You have not defined any stopping_criteria; this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(constrained_beam_scorer._beam_hyps)
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
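        # all but the first beam of each batch element were masked with -1e9 above so that the first top-k selection
        # does not pick `num_beams` identical continuations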

        this_peer_finished = False  # used by synced_gpus only
        while True:

            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)

            scores_for_all_vocab = next_token_scores_processed.clone()
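            # the constrained scorer needs the processed scores over the full vocabulary (not only the top-k candidates
            # selected below) in order to score the tokens that advance each not-yet-satisfied constraint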

            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = (next_tokens / vocab_size).long()
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = constrained_beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                scores_for_all_vocab,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1

            if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = constrained_beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
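
    Example (an illustrative sketch; the tensor shape and the `top_k`/`top_p` values below are arbitrary):

    ```python
    >>> import torch
    >>> logits = torch.randn(2, 50)  # (batch size, vocabulary size)
    >>> filtered = top_k_top_p_filtering(logits, top_k=10, top_p=0.95)
    >>> filtered.shape
    torch.Size([2, 50])
    ```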
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits)

    return logits