# coding=utf-8
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial
from typing import Dict, Optional

import numpy as np

import flax
import jax
import jax.numpy as jnp
from jax import lax

from .generation_flax_logits_process import (
    FlaxForcedBOSTokenLogitsProcessor,
    FlaxForcedEOSTokenLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)
from .utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None


@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using sampling.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None


@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using beam search.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jnp.ndarray = None
    scores: jnp.ndarray = None


@flax.struct.dataclass
class GreedyState:
    cur_len: jnp.ndarray
    sequences: jnp.ndarray
    running_token: jnp.ndarray
    is_sent_finished: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]


@flax.struct.dataclass
class SampleState:
    cur_len: jnp.ndarray
    sequences: jnp.ndarray
    running_token: jnp.ndarray
    is_sent_finished: jnp.ndarray
    prng_key: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]


@flax.struct.dataclass
class BeamSearchState:
    cur_len: jnp.ndarray
    running_sequences: jnp.ndarray
    running_scores: jnp.ndarray
    sequences: jnp.ndarray
    scores: jnp.ndarray
    is_sent_finished: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]


class FlaxGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in
    [`FlaxPreTrainedModel`].

    The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for:
            - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
              `num_beams=1` and `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
              and `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
              and `do_sample=False`.
    """

    @staticmethod
    def _run_loop_in_debug(cond_fn, body_fn, init_state):
        """
        Run generation in untraced mode. This should only be used for debugging purposes.
        """
        state = init_state
        while cond_fn(state):
            state = body_fn(state)
        return state

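    # A minimal sketch (illustrative only, not part of the public API) of how the debug loop
    # above mirrors `lax.while_loop` semantics in plain Python:
    #
    #     cond_fn = lambda s: s < 3
    #     body_fn = lambda s: s + 1
    #     FlaxGenerationMixin._run_loop_in_debug(cond_fn, body_fn, 0)  # returns 3
    #     lax.while_loop(cond_fn, body_fn, 0)  # traced equivalent
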
    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
        }
        model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
        return model_kwargs

    @staticmethod
    def _expand_to_num_beams(tensor, num_beams):
        return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])

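    # Shape sketch for `_expand_to_num_beams` (toy shapes chosen purely for illustration): an array
    # of shape (batch_size, seq_len) = (2, 5) expanded with num_beams=3 becomes (2, 3, 5), i.e. each
    # sequence is repeated once per beam:
    #
    #     x = jnp.arange(10).reshape(2, 5)
    #     FlaxGenerationMixin._expand_to_num_beams(x, num_beams=3).shape  # (2, 3, 5)
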
    def _adapt_logits_for_beam_search(self, logits):
        """
        This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
        search behavior. Note that the only model that overwrites this method is [`~transformers.FlaxMarianMTModel`].
        """
        return logits

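    # Hypothetical override sketch (this is *not* the actual FlaxMarianMTModel implementation):
    # a subclass could, for instance, keep beam search from ever selecting the pad token by
    # pushing its logit to a very large negative value before the top-k step:
    #
    #     def _adapt_logits_for_beam_search(self, logits):
    #         # logits has shape (batch_size, num_beams, vocab_size) at this point
    #         return logits.at[:, :, self.config.pad_token_id].set(-float("inf"))
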
    def generate(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jnp.ndarray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head. The method supports the following
        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

            - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
              `num_beams=1` and `do_sample=False`.
            - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
              and `do_sample=True`.
            - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
              and `do_sample=False`.

        <Tip warning={true}>

        Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name as
        defined in the model's config (`config.json`) which in turn defaults to the
        [`~modeling_utils.PretrainedConfig`] of the model.

        </Tip>

        Most of these parameters are explained in more detail in [this blog
        post](https://huggingface.co/blog/how-to-generate).

        Parameters:

            input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            max_length (`int`, *optional*, defaults to 20):
                The maximum length of the sequence to be generated.
            do_sample (`bool`, *optional*, defaults to `False`):
                Whether or not to use sampling; use greedy decoding otherwise.
            temperature (`float`, *optional*, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (`int`, *optional*, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`, *optional*, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
                are kept for generation.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            bos_token_id (`int`, *optional*):
                The id of the *beginning-of-sequence* token.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            decoder_start_token_id (`int`, *optional*):
                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            trace (`bool`, *optional*, defaults to `True`):
                Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
                considerably slower runtime.
            params (`Dict[str, jnp.ndarray]`, *optional*):
                Optionally the model parameters can be passed. Can be useful for parallelized generation.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.

        Return:
            [`~utils.ModelOutput`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> input_context = "The dog"
        >>> # encode input context
        >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
        >>> # generate candidates using sampling
        >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        ```"""
        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
        if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
                f"length ({max_length})"
            )

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
            # prepare decoder_input_ids for generation
            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )

            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")

    def _get_logits_warper(
        self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
    ) -> FlaxLogitsProcessorList:
        """
        This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
        instances used for multinomial sampling.
        """

        # init warp parameters
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        temperature = temperature if temperature is not None else self.config.temperature
        # instantiate warpers list
        warpers = FlaxLogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(FlaxTemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
        if top_p is not None and top_p < 1.0:
            warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))

        return warpers

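    # Illustrative sketch of the warper chain built above (parameter values are examples only):
    # with temperature=0.7, top_k=50 and top_p=0.95, the returned list applies, in order,
    # FlaxTemperatureLogitsWarper -> FlaxTopKLogitsWarper -> FlaxTopPLogitsWarper:
    #
    #     warpers = model._get_logits_warper(top_k=50, top_p=0.95, temperature=0.7)
    #     scores = warpers(input_ids, scores, cur_len)  # each warper rewrites `scores` in turn
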
    def _get_logits_processor(
        self,
        no_repeat_ngram_size: int,
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
    ) -> FlaxLogitsProcessorList:
        """
        This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        processors = FlaxLogitsProcessorList()

        # init warp parameters
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if min_length is not None and eos_token_id is not None and min_length > -1:
            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
        if forced_bos_token_id is not None:
            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        return processors

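    # Illustrative sketch (token ids are made up for the example): with min_length=5, max_length=20,
    # eos_token_id=2, forced_bos_token_id=0 and forced_eos_token_id=2, the list built above contains a
    # FlaxMinLengthLogitsProcessor, a FlaxForcedBOSTokenLogitsProcessor and a FlaxForcedEOSTokenLogitsProcessor:
    #
    #     processors = model._get_logits_processor(None, 5, 20, 2, 0, 2)
    #     scores = processors(sequences, scores, cur_len)  # applied in order during decoding
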
    def _greedy_search(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self
        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = GreedyState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def greedy_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def greedy_search_body_fn(state):
            """state update fn."""
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)
            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)

            next_token = jnp.argmax(logits, axis=-1)

            next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
            return GreedyState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = greedy_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
        else:
            state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)

        return FlaxGreedySearchOutput(sequences=state.sequences)

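    # Illustrative end-to-end usage (mirrors the `generate` docstring example; `generate` dispatches
    # to `_greedy_search` when `num_beams=1` and `do_sample=False`):
    #
    #     from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
    #
    #     tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    #     model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
    #     input_ids = tokenizer("The dog", return_tensors="np").input_ids
    #     output = model.generate(input_ids, max_length=20, do_sample=False)
    #     tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
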
    def _sample(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)

    def _beam_search(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

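        # Toy sketch of `gather_beams` (shapes chosen only for illustration): with batch_size=1,
        # new_num_beams=2 and beam_indices=[[1, 0]], every (batch, beam, ...) array in `nested`
        # is reordered so that new beam 0 is old beam 1 and new beam 1 is old beam 0:
        #
        #     seqs = jnp.arange(6).reshape(1, 2, 3)          # (batch, beams, length)
        #     gather_beams(seqs, jnp.array([[1, 0]]), 1, 2)  # beams swapped along axis 1
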
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            # unflatten beam dimension
            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
            # Unflatten beam dimension in attention cache arrays
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather 2*k top beams.
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            # Recover token id by modulo division and expand Id array for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update sequences for the 2*K top-k new sequences.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)