Unverified commit 996a315e authored by Patrick von Platen, committed by GitHub

Flax Generate (#11777)



* fix_torch_device_generate_test

* remove @

* add

* indexing

* correct a couple of tests

* fix tests

* add logits processor

* finish top_k, top_p, temp

* add docs

* correct flax prng key default

* improve generate

* add generation docs

* add docs

* make style

* revert model outputs change

* make style

* correct typo

* fix tests

* fix slow test

* add raise

* finish generation
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 2df54691
...@@ -78,6 +78,9 @@ GreedySearchOutput
.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
:members:
.. autoclass:: transformers.generation_flax_utils.FlaxGreedySearchOutput
:members:
SampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...@@ -88,6 +91,9 @@ SampleOutput
.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
:members:
.. autoclass:: transformers.generation_flax_utils.FlaxSampleOutput
:members:
BeamSearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...@@ -160,6 +166,24 @@ generation.
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
:members: __call__
.. autoclass:: transformers.FlaxLogitsProcessor
:members: __call__
.. autoclass:: transformers.FlaxLogitsProcessorList
:members: __call__
.. autoclass:: transformers.FlaxLogitsWarper
:members: __call__
.. autoclass:: transformers.FlaxTemperatureLogitsWarper
:members: __call__
.. autoclass:: transformers.FlaxTopPLogitsWarper
:members: __call__
.. autoclass:: transformers.FlaxTopKLogitsWarper
:members: __call__
StoppingCriteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -26,8 +26,9 @@ are common among all the models to:
The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin`
(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models) or
for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models),
:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models) and
:class:`~transformers.generation_flax_utils.FlaxGenerationMixin` (for the Flax/JAX models).
PreTrainedModel
...@@ -74,6 +75,9 @@ Generation
.. autoclass:: transformers.generation_tf_utils.TFGenerationMixin
:members:
.. autoclass:: transformers.generation_flax_utils.FlaxGenerationMixin
:members:
Pushing to the Hub
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -1437,6 +1437,14 @@ else:
# FLAX-backed objects
if is_flax_available():
_import_structure["generation_flax_logits_process"] = [
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
]
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.auto"].extend( _import_structure["models.auto"].extend(
[ [
...@@ -2693,6 +2701,14 @@ if TYPE_CHECKING:
from .utils.dummy_tf_objects import *
if is_flax_available():
from .generation_flax_logits_process import (
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from abc import ABC
import jax
import jax.lax as lax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from .file_utils import add_start_docstrings
from .utils.logging import get_logger
logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary token when not using
beam search, or log-softmax scores for each vocabulary token when using beam search.
kwargs:
Additional logits processor specific kwargs.
Return:
:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class FlaxLogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
"""Flax method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class FlaxLogitsWarper(ABC):
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
"""Flax method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class FlaxLogitsProcessorList(list):
"""
This class can be used to create a list of :class:`~transformers.FlaxLogitsProcessor` or
:class:`~transformers.FlaxLogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits
from list and adds a specific `__call__` method to apply each :class:`~transformers.FlaxLogitsProcessor` or
:class:`~transformers.FlaxLogitsWarper` to the inputs.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
assert all(
arg in kwargs for arg in list(function_args.keys())[2:]
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
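# Illustrative usage sketch (hypothetical shapes and values), combining the warpers
# defined below:
#
#     import jax.numpy as jnp
#     warpers = FlaxLogitsProcessorList(
#         [FlaxTemperatureLogitsWarper(temperature=0.7), FlaxTopKLogitsWarper(top_k=5)]
#     )
#     scores = jnp.zeros((1, 50257))  # (batch_size, vocab_size)
#     warped_scores = warpers(input_ids=None, scores=scores)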
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
r"""
:class:`transformers.FlaxLogitsWarper` for temperature (exponential scaling of the output probability
distribution).
Args:
temperature (:obj:`float`):
The value used to modulate the logits distribution.
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
scores = scores / self.temperature
return scores
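# Worked example: dividing logits [2.0, 0.0] by temperature=0.5 gives [4.0, 0.0],
# sharpening the softmax from roughly [0.88, 0.12] to [0.98, 0.02]; a temperature
# above 1 flattens the distribution instead.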
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
"""
:class:`transformers.FlaxLogitsWarper` that performs nucleus (top-p) filtering, i.e. restricting the kept
tokens to the smallest set whose cumulative probability is at least :obj:`top_p`.
Args:
top_p (:obj:`float`):
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
kept for generation.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
mask_scores = jnp.full_like(scores, self.filter_value)
cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
score_mask = cumulative_probs < self.top_p
# include the token that is higher than top_p as well
score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True)
# min tokens to keep
score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True)
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
return next_scores
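# Worked example: for sorted probabilities [0.5, 0.3, 0.1, 0.1] and top_p=0.7, the
# cumulative sums are [0.5, 0.8, 0.9, 1.0], so `cumulative_probs < top_p` keeps only
# the first token; the rolled-mask update above additionally keeps the token that
# crosses the threshold (0.3), and the two 0.1 tokens are set to `filter_value`.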
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
r"""
:class:`transformers.FlaxLogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (:obj:`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
batch_size, vocab_size = scores.shape
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
topk_scores, topk_indices = lax.top_k(scores, topk)
shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
topk_scores_flat = topk_scores.flatten()
topk_indices_flat = topk_indices.flatten() + shift
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
return next_scores
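# Worked example (hypothetical values):
#
#     import jax.numpy as jnp
#     warper = FlaxTopKLogitsWarper(top_k=2)
#     out = warper(None, jnp.array([[1.0, 3.0, 2.0, 0.0]]))
#     # out == [[-inf, 3.0, 2.0, -inf]] -- only the two highest logits survive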
# coding=utf-8
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
import flax
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from jax import lax
from .file_utils import ModelOutput
from .generation_flax_logits_process import (
FlaxLogitsProcessorList,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from .utils import logging
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
"""
Flax base class for outputs of decoder-only generation models using greedy search.
Args:
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
padded to :obj:`max_length`.
"""
sequences: jax_xla.DeviceArray = None
@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
"""
Flax base class for outputs of decoder-only generation models using sampling.
Args:
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
padded to :obj:`max_length`.
"""
sequences: jax_xla.DeviceArray = None
@flax.struct.dataclass
class GreedyState:
cur_len: jax_xla.DeviceArray
sequences: jax_xla.DeviceArray
current_token: jax_xla.DeviceArray
is_sent_finished: jax_xla.DeviceArray
model_kwargs: Dict[str, jax_xla.DeviceArray]
@flax.struct.dataclass
class SampleState:
cur_len: jax_xla.DeviceArray
sequences: jax_xla.DeviceArray
current_token: jax_xla.DeviceArray
is_sent_finished: jax_xla.DeviceArray
prng_key: jax_xla.DeviceArray
model_kwargs: Dict[str, jax_xla.DeviceArray]
class FlaxGenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
:class:`~transformers.FlaxPreTrainedModel`.
"""
@staticmethod
def _run_loop_in_debug(cond_fn, body_fn, init_state):
"""
Run generation in untraced mode. This should only be used for debugging purposes.
"""
state = init_state
while cond_fn(state):
state = body_fn(state)
return state
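# Sketch of the difference: with `trace=True` the cond/body functions defined below
# are compiled into a single `lax.while_loop`, so Python-side `print` calls inside
# them run only once at trace time; with `trace=False` they execute eagerly through
# this helper, making `print` and debugger breakpoints usable on every iteration.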
def generate(
self,
input_ids: jax_xla.DeviceArray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
do_sample: Optional[bool] = None,
prng_key: Optional[jax_xla.DeviceArray] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
trace: bool = True,
**model_kwargs,
):
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding
and multinomial sampling.
Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same
name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
default values of those config attributes.
Most of these parameters are explained in more detail in `this blog post
<https://huggingface.co/blog/how-to-generate>`__.
Parameters:
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use sampling; use greedy decoding otherwise.
temperature (:obj:`float`, `optional`, defaults to 1.0):
The value used to modulate the next token probabilities.
top_k (:obj:`int`, `optional`, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (:obj:`float`, `optional`, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
higher are kept for generation.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
prng_key (:obj:`jax_xla.DeviceArray`, `optional`):
A PRNG key used as the random seed for sampling; defaults to :obj:`jax.random.PRNGKey(0)`.
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
a considerably slower runtime.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
:class:`~transformers.file_utils.ModelOutput`.
Examples::
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
>>> input_context = "The dog"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
>>> # generate candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
do_sample = do_sample if do_sample is not None else self.config.do_sample
if do_sample:
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
return self._sample(
input_ids,
max_length,
pad_token_id,
eos_token_id,
prng_key,
logits_warper=logits_warper,
model_kwargs=model_kwargs,
trace=trace,
)
else:
return self._greedy_search(
input_ids, max_length, pad_token_id, eos_token_id, trace=trace, model_kwargs=model_kwargs
)
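# Usage sketch (mirrors the tests accompanying this commit; model name and pad-token
# handling are illustrative): since `generate` is functional over its array inputs,
# it can be jit-compiled as a whole:
#
#     import jax
#     from transformers import FlaxGPT2LMHeadModel, GPT2Tokenizer
#
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
#     model.config.pad_token_id = model.config.eos_token_id
#     input_ids = tokenizer("The dog", return_tensors="jax").input_ids
#     jit_generate = jax.jit(model.generate)
#     sequences = jit_generate(input_ids).sequences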
def _get_logits_warper(
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> FlaxLogitsProcessorList:
"""
This method returns a :class:`~transformers.FlaxLogitsProcessorList` object that contains all relevant
:class:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling.
"""
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = FlaxLogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all warpers can be found in `generation_flax_logits_process.py`
if temperature is not None and temperature != 1.0:
warpers.append(FlaxTemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
if top_p is not None and top_p < 1.0:
warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
return warpers
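# For example, top_k=50, top_p=0.9, temperature=0.7 would yield a list equivalent to
#
#     FlaxLogitsProcessorList([
#         FlaxTemperatureLogitsWarper(0.7),
#         FlaxTopKLogitsWarper(top_k=50, min_tokens_to_keep=1),
#         FlaxTopPLogitsWarper(top_p=0.9, min_tokens_to_keep=1),
#     ])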
def _greedy_search(
self,
input_ids: jax_xla.DeviceArray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
trace: bool = True,
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
):
# init values
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
batch_size, cur_len = input_ids.shape
eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
cur_len = jnp.array(cur_len)
# per batch-item holding current token in loop.
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
model = self
# initialize model specific kwargs
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = GreedyState(
cur_len=cur_len,
sequences=sequences,
current_token=input_ids,
is_sent_finished=is_sent_finished,
model_kwargs=model_kwargs,
)
def greedy_search_cond_fn(state):
"""state termination condition fn."""
has_reached_max_length = state.cur_len == max_length
all_sequence_finished = jnp.all(state.is_sent_finished)
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
return ~finish_generation
def greedy_search_body_fn(state):
"""state update fn."""
model_outputs = model(state.current_token, **state.model_kwargs)
next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
next_model_kwargs = model.update_inputs_for_generation(model_outputs, state.model_kwargs)
return GreedyState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
current_token=next_token,
is_sent_finished=next_is_sent_finished,
model_kwargs=next_model_kwargs,
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
state = greedy_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
else:
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
return FlaxGreedySearchOutput(sequences=state.sequences)
def _sample(
self,
input_ids: jax_xla.DeviceArray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
prng_key: Optional[jax_xla.DeviceArray] = None,
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
logits_warper: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
):
# init values
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
batch_size, cur_len = input_ids.shape
eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
cur_len = jnp.array(cur_len)
# per batch-item holding current token in loop.
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
model = self
# initialize model specific kwargs
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = SampleState(
cur_len=cur_len,
sequences=sequences,
current_token=input_ids,
is_sent_finished=is_sent_finished,
prng_key=prng_key,
model_kwargs=model_kwargs,
)
def sample_search_cond_fn(state):
"""state termination condition fn."""
has_reached_max_length = state.cur_len == max_length
all_sequence_finished = jnp.all(state.is_sent_finished)
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
return ~finish_generation
def sample_search_body_fn(state):
"""state update fn."""
prng_key, prng_key_next = jax.random.split(state.prng_key)
model_outputs = model(state.current_token, **state.model_kwargs)
logits = model_outputs.logits[:, -1]
# apply top_p, top_k, temperature
logits = logits_warper(state.sequences, logits)
next_token = jax.random.categorical(prng_key, logits, axis=-1)
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
next_model_kwargs = model.update_inputs_for_generation(model_outputs, state.model_kwargs)
return SampleState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
current_token=next_token,
is_sent_finished=next_is_sent_finished,
model_kwargs=next_model_kwargs,
prng_key=prng_key_next,
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
state = sample_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
else:
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
return FlaxSampleOutput(sequences=state.sequences)
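# Usage sketch: sampling is deterministic given a PRNG key, so results can be
# reproduced by passing the same key (values illustrative):
#
#     key = jax.random.PRNGKey(42)
#     output = model.generate(input_ids, do_sample=True, prng_key=key)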
...@@ -41,6 +41,7 @@ from .file_utils import (
is_remote_url,
replace_return_docstrings,
)
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
...@@ -57,7 +58,7 @@ ACT2FN = {
}
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
r"""
Base class for all models.
......
...@@ -20,7 +20,6 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.linen import combine_masks, dot_product_attention, make_causal_mask
from flax.traverse_util import flatten_dict
from jax import lax
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
...@@ -322,13 +321,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@property
def _attn_layer_name(self):
attn_layer_key_tuple = ("h", "0", "attn")
if self.base_model_prefix in set(self.params.keys()):
attn_layer_key_tuple = (self.base_model_prefix,) + attn_layer_key_tuple
return attn_layer_key_tuple
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
...@@ -381,28 +373,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
batch_size, sequence_length = input_ids.shape
if position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
...@@ -627,6 +604,32 @@ class FlaxGPT2LMHeadModule(nn.Module):
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
module_class = FlaxGPT2LMHeadModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since GPT2 uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
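# Worked example: for a left-padded attention_mask [[0, 1, 1]],
# `attention_mask.cumsum(axis=-1) - 1` yields position_ids [[-1, 0, 1]], so the two
# real tokens get positions 0 and 1; `update_inputs_for_generation` then advances
# the position for the next step via `position_ids[:, -1:] + 1 -> [[2]]`.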
append_call_sample_docstring(
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
)
......
...@@ -2,6 +2,36 @@
from ..file_utils import requires_backends
class FlaxLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLogitsProcessorList:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLogitsWarper:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxTemperatureLogitsWarper:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxTopKLogitsWarper:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxTopPLogitsWarper:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
# coding=utf-8
# Copyright 2021 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax
from .test_modeling_flax_common import ids_tensor
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.generation_flax_logits_process import (
FlaxLogitsProcessorList,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
@require_flax
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = np.ones((batch_size, length)) / length
return scores
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 2nd batch item
scores[1, 10] = (1 / length) - 0.4 # valley, 2nd batch item
# compute softmax
probs = jax.nn.softmax(scores, axis=-1)
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
# uniform distribution stays uniform
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
# sharp peaks get higher, valleys get lower
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
# smooth peaks get lower, valleys get higher
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = FlaxTopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
# check that correct tokens are filtered
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special case
length = 5
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse of the softmax applied inside FlaxTopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p_warp = FlaxTopPLogitsWarper(0.7)
filtered_dist = np.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() - (
vocab_size // 2
)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.copy()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.copy()
# instantiate all dist processors
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
# no processor list
scores = temp_dist_warp(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
# with processor list
processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
scores_comp = processor(input_ids, scores_comp)
# scores should be equal
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax
if is_flax_available():
import os
import jax
import jax.numpy as jnp
from jax import jit
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
output = np.array(values, dtype=jnp.int32).reshape(shape)
return output
def random_attention_mask(shape, rng=None):
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
# make sure that at least one token is attended to for each batch
attn_mask[:, -1] = 1
return attn_mask
@require_flax
class FlaxGenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
def _get_input_ids_and_config(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = inputs["input_ids"].shape[-1] // 2
input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
attention_mask = jnp.ones_like(input_ids)
attention_mask = attention_mask[:max_batch_size, :sequence_length]
# generate max 5 tokens
max_length = input_ids.shape[-1] + 5
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
def test_greedy_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
config.max_length = max_length
config.temperature = 0.8
config.top_k = 10
config.top_p = 0.3
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_greedy_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.do_sample = False
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.do_sample = True
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
...@@ -19,16 +19,16 @@ import unittest
import numpy as np
import transformers
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
from jax import lax
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
...@@ -116,8 +116,25 @@ class FlaxGPT2ModelTester:
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
...@@ -134,10 +151,22 @@ class FlaxGPT2ModelTester:
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
...@@ -145,66 +174,12 @@ class FlaxGPT2ModelTester:
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_generation(self, config, input_ids):
prompt_length = 3
model = FlaxGPT2LMHeadModel(config)
max_length = 10
batch_size = 1
prompt_ids = input_ids[:1, :prompt_length]
# put all generation logic into one function
def generate(prompt_ids):
def first_pass(prompt_ids):
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
next_token = jnp.argmax(logits[:, -1:], axis=-1)
return next_token, cache
def greedy_search_cond_fn(state):
cur_len, _, _, _ = state
return ~(cur_len == max_length - 1)
def greedy_search_body_fn(state):
cur_len, sequences, current_token, cache = state
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
next_token = jnp.argmax(next_logits, axis=-1)
return cur_len + 1, next_sequences, next_token, next_cache
# init tensor to be filled with generation result
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
# init past key values for cache
past_key_values = model.init_cache(batch_size, max_length)
# first pass with long prompt
next_token, cache = first_pass(prompt_ids)
# prepare state for generation loop
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
# fast generation
_, output_sequences, final_token, _ = lax.while_loop(
greedy_search_cond_fn, greedy_search_body_fn, init_state
)
# append last token
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
return output_sequences
jit_generate = jax.jit(generate)
output_sequences = jit_generate(prompt_ids)
self.parent.assertEqual(output_sequences.shape, (1, max_length))
@require_flax
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxGPT2ModelTester(self)
...@@ -221,9 +196,27 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
model_class_name, config, input_ids, attention_mask
)
@slow
def test_batch_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
model.config.do_sample = False
model.config.pad_token_id = model.config.eos_token_id
jit_generate = jax.jit(model.generate)
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
expected_string = [
"Hello this is a long string of words. I'm going to try to explain what I mean.",
"Hey, I'm not sure if I'm going to be able to do",
]
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slightly differently
......