Unverified Commit cd9274d0 authored by Sanchit Gandhi, committed by GitHub

[FlaxBert] Add ForCausalLM (#16995)

* [FlaxBert] Add ForCausalLM

* make style

* fix output attentions

* Add RobertaForCausalLM

* remove comment

* fix fx-to-pt model loading

* remove comment

* add modeling tests

* add enc-dec model tests

* add big_bird

* add electra

* make style

* make repo-consistency

* add to docs

* remove roberta test

* quality

* amend cookiecutter

* fix attention_mask bug in flax bert model tester

* tighten pt-fx thresholds to 1e-5

* add 'copied from' statements

* amend 'copied from' statements

* amend 'copied from' statements

* quality
parent 31616b8d
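The new head is used like any other Flax task head. A minimal sketch (not part of this diff; `bert-base-uncased` is only an example checkpoint, and plain BERT checkpoints are not trained for left-to-right generation, so treat this as a smoke test):

```python
from transformers import AutoTokenizer, FlaxBertForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertForCausalLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)

# language-modeling scores over the vocabulary; pick the last position for next-token prediction
next_token_logits = outputs.logits[:, -1]
```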
...@@ -166,6 +166,11 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
[[autodoc]] FlaxBertForPreTraining
- __call__

## FlaxBertForCausalLM

[[autodoc]] FlaxBertForCausalLM
- __call__

## FlaxBertForMaskedLM

[[autodoc]] FlaxBertForMaskedLM
......
...@@ -120,6 +120,11 @@ This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta
[[autodoc]] FlaxBigBirdForPreTraining
- __call__

## FlaxBigBirdForCausalLM

[[autodoc]] FlaxBigBirdForCausalLM
- __call__

## FlaxBigBirdForMaskedLM

[[autodoc]] FlaxBigBirdForMaskedLM
......
...@@ -158,6 +158,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
[[autodoc]] FlaxElectraForPreTraining
- __call__

## FlaxElectraForCausalLM

[[autodoc]] FlaxElectraForCausalLM
- __call__

## FlaxElectraForMaskedLM

[[autodoc]] FlaxElectraForMaskedLM
......
...@@ -136,6 +136,11 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o
[[autodoc]] FlaxRobertaModel
- __call__

## FlaxRobertaForCausalLM

[[autodoc]] FlaxRobertaForCausalLM
- __call__

## FlaxRobertaForMaskedLM

[[autodoc]] FlaxRobertaForMaskedLM
......
...@@ -2314,6 +2314,7 @@ if is_flax_available():
)
_import_structure["models.bert"].extend(
[
"FlaxBertForCausalLM",
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
...@@ -2327,6 +2328,7 @@ if is_flax_available():
)
_import_structure["models.big_bird"].extend(
[
"FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
...@@ -2370,6 +2372,7 @@ if is_flax_available():
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
...@@ -2412,6 +2415,7 @@ if is_flax_available():
)
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
...@@ -4363,6 +4367,7 @@ if TYPE_CHECKING:
FlaxBeitPreTrainedModel,
)
from .models.bert import (
FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
...@@ -4374,6 +4379,7 @@ if TYPE_CHECKING:
FlaxBertPreTrainedModel,
)
from .models.big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
...@@ -4411,6 +4417,7 @@ if TYPE_CHECKING:
FlaxDistilBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
...@@ -4435,6 +4442,7 @@ if TYPE_CHECKING:
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
......
...@@ -106,6 +106,55 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[jnp.ndarray]] = None

@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None

@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
......
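For orientation only (not part of this diff; the checkpoint name is illustrative), the base Flax BERT model now returns the new dataclass, so downstream code can read the extra fields:

```python
from transformers import AutoTokenizer, FlaxBertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

outputs = model(**tokenizer("Flax BERT with cross-attention outputs", return_tensors="np"))
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size)
# cross_attentions and past_key_values stay None unless cross-attention or caching is actually used
```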
...@@ -127,6 +127,10 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
]
)
......
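With the mapping above, the Flax auto class can resolve BERT-style configs to the new causal-LM heads. A sketch (not part of the diff; checkpoint name is only an example):

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = FlaxAutoModelForCausalLM.from_pretrained("bert-base-uncased")  # resolves to FlaxBertForCausalLM via the mapping

logits = model(**tokenizer("auto-mapped causal LM", return_tensors="np")).logits
```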
...@@ -65,6 +65,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_bert"] = [
"FlaxBertForCausalLM",
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
...@@ -119,6 +120,7 @@ if TYPE_CHECKING:
if is_flax_available():
from .modeling_flax_bert import (
FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
......
...@@ -22,13 +22,16 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxBaseModelOutputWithPooling,
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxNextSentencePredictorOutput,
...@@ -212,9 +215,11 @@ class FlaxBertEmbeddings(nn.Module):
class FlaxBertSelfAttention(nn.Module):
config: BertConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32  # the dtype of the computation

def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads

if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
...@@ -237,30 +242,113 @@ class FlaxBertSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
key_value_states: Optional[jnp.array] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
):
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]

# get query proj
query_states = self.query(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.key(key_value_states)
value_states = self.value(key_value_states)
else:
# self_attention
key_states = self.key(hidden_states)
value_states = self.value(hidden_states)

query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)

# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)

# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
...@@ -318,10 +406,11 @@ class FlaxBertSelfOutput(nn.Module):
class FlaxBertAttention(nn.Module):
config: BertConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)

def __call__(
...@@ -329,6 +418,8 @@ class FlaxBertAttention(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
...@@ -339,6 +430,8 @@ class FlaxBertAttention(nn.Module):
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
...@@ -396,27 +489,46 @@ class FlaxBertLayer(nn.Module):
dtype: jnp.dtype = jnp.float32  # the dtype of the computation

def setup(self):
self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)

def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
# Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
# Cross-Attention Block
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
...@@ -424,6 +536,8 @@ class FlaxBertLayer(nn.Module):
if output_attentions:
outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)

return outputs
...@@ -441,6 +555,9 @@ class FlaxBertLayerCollection(nn.Module):
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
...@@ -448,6 +565,7 @@ class FlaxBertLayerCollection(nn.Module):
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
...@@ -465,6 +583,9 @@ class FlaxBertLayerCollection(nn.Module):
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
...@@ -474,6 +595,9 @@ class FlaxBertLayerCollection(nn.Module):
if output_attentions:
all_attentions += (layer_outputs[1],)

if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)

if output_hidden_states:
all_hidden_states += (hidden_states,)
...@@ -482,8 +606,11 @@ class FlaxBertLayerCollection(nn.Module):
if not return_dict:
return tuple(v for v in outputs if v is not None)

return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
...@@ -499,6 +626,9 @@ class FlaxBertEncoder(nn.Module):
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
...@@ -508,6 +638,9 @@ class FlaxBertEncoder(nn.Module):
hidden_states,
attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
...@@ -639,9 +772,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)

random_params = module_init_outputs["params"]

if params is not None:
random_params = flatten_dict(unfreeze(random_params))
...@@ -653,6 +803,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
else:
return random_params
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
...@@ -661,12 +831,15 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
token_type_ids=None,
position_ids=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
past_key_values: dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
...@@ -692,20 +865,61 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
if dropout_rng is not None:
rngs["dropout"] = dropout_rng

inputs = {"params": params or self.params}
if self.config.add_cross_attention:
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
# changed by FlaxBertAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
else:
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
)
return outputs
class FlaxBertModule(nn.Module):
config: BertConfig
...@@ -721,9 +935,12 @@ class FlaxBertModule(nn.Module):
self,
input_ids,
attention_mask,
token_type_ids: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
...@@ -745,6 +962,9 @@ class FlaxBertModule(nn.Module):
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
...@@ -758,11 +978,12 @@ class FlaxBertModule(nn.Module):
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]

return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
...@@ -1313,3 +1534,108 @@ append_call_sample_docstring(
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
class FlaxBertForCausalLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
token_type_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
autoregressive tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
module_class = FlaxBertForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
FlaxBertForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
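The `init_cache` / `prepare_inputs_for_generation` plumbing above is what lets the Flax generation utilities run cached autoregressive decoding. A sketch (not part of the diff; the checkpoint is only an example, and an MLM-pretrained BERT will not generate fluent text, so this is only a smoke test of the cache path):

```python
from transformers import AutoTokenizer, FlaxBertForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertForCausalLM.from_pretrained("bert-base-uncased")

input_ids = tokenizer("A causal BERT", return_tensors="np").input_ids
# greedy decoding; generate() calls init_cache and prepare_inputs_for_generation internally
generated = model.generate(input_ids, max_length=20, do_sample=False)
print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))
```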
...@@ -55,6 +55,7 @@ if is_torch_available():
if is_flax_available():
_import_structure["modeling_flax_big_bird"] = [
"FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
...@@ -92,6 +93,7 @@ if TYPE_CHECKING:
if is_flax_available():
from .modeling_flax_big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
......
...@@ -22,13 +22,16 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxBaseModelOutputWithPooling,
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxSequenceClassifierOutput,
...@@ -234,9 +237,11 @@ class FlaxBigBirdEmbeddings(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->BigBird
class FlaxBigBirdSelfAttention(nn.Module):
config: BigBirdConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32  # the dtype of the computation

def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads

if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
...@@ -259,30 +264,113 @@ class FlaxBigBirdSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
key_value_states: Optional[jnp.array] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
):
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]

# get query proj
query_states = self.query(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.key(key_value_states)
value_states = self.value(key_value_states)
else:
# self_attention
key_states = self.key(hidden_states)
value_states = self.value(hidden_states)

query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)

# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)

# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
...@@ -1118,11 +1206,12 @@ class FlaxBigBirdSelfOutput(nn.Module):
class FlaxBigBirdAttention(nn.Module):
config: BigBirdConfig
layer_id: int = None
causal: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
if self.config.attention_type == "original_full":
self.self = FlaxBigBirdSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
elif self.config.attention_type == "block_sparse":
self.self = FlaxBigBirdBlockSparseAttention(self.config, block_sparse_seed=self.layer_id, dtype=self.dtype)
else:
...@@ -1137,6 +1226,8 @@ class FlaxBigBirdAttention(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
...@@ -1148,6 +1239,8 @@ class FlaxBigBirdAttention(nn.Module):
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
...@@ -1215,9 +1308,13 @@ class FlaxBigBirdLayer(nn.Module):
dtype: jnp.dtype = jnp.float32  # the dtype of the computation

def setup(self):
self.attention = FlaxBigBirdAttention(
self.config, layer_id=self.layer_id, causal=self.config.is_decoder, dtype=self.dtype
)
self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype)
self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = FlaxBigBirdAttention(self.config, causal=False, dtype=self.dtype)

# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird
def __call__(
...@@ -1225,18 +1322,35 @@ class FlaxBigBirdLayer(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
# Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
# Cross-Attention Block
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
...@@ -1244,6 +1358,8 @@ class FlaxBigBirdLayer(nn.Module):
if output_attentions:
outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)

return outputs
...@@ -1263,6 +1379,9 @@ class FlaxBigBirdLayerCollection(nn.Module):
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
...@@ -1270,6 +1389,7 @@ class FlaxBigBirdLayerCollection(nn.Module):
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
...@@ -1287,6 +1407,9 @@ class FlaxBigBirdLayerCollection(nn.Module):
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
...@@ -1296,6 +1419,9 @@ class FlaxBigBirdLayerCollection(nn.Module):
if output_attentions:
all_attentions += (layer_outputs[1],)

if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)

if output_hidden_states:
all_hidden_states += (hidden_states,)
...@@ -1304,8 +1430,11 @@ class FlaxBigBirdLayerCollection(nn.Module):
if not return_dict:
return tuple(v for v in outputs if v is not None)

return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
...@@ -1322,6 +1451,9 @@ class FlaxBigBirdEncoder(nn.Module):
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
...@@ -1331,6 +1463,9 @@ class FlaxBigBirdEncoder(nn.Module):
hidden_states,
attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
...@@ -1432,6 +1567,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
...@@ -1443,9 +1579,26 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)

random_params = module_init_outputs["params"]

if params is not None:
random_params = flatten_dict(unfreeze(random_params))
...@@ -1457,7 +1610,28 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
else:
return random_params
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
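A minimal, hedged sketch of how this cache initialization would typically be exercised; the config values, layer count, and attention type below are illustrative assumptions rather than anything pinned down by the code above:

# Illustrative only: build a small randomly-initialized decoder-style BigBird and allocate its cache.
from transformers import BigBirdConfig, FlaxBigBirdForCausalLM

config = BigBirdConfig(
    num_hidden_layers=2,
    is_decoder=True,
    add_cross_attention=True,
    attention_type="original_full",  # assumption: full attention keeps the toy sequence length unconstrained
)
model = FlaxBigBirdForCausalLM(config)

batch_size, max_length = 2, 16
past_key_values = model.init_cache(batch_size, max_length)
# past_key_values is a nested dict of zero-filled cached_key / cached_value buffers of shape
# (batch_size, max_length, num_attention_heads, head_dim), plus one scalar cache_index per layer.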
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->BigBird
def __call__( def __call__(
self, self,
input_ids, input_ids,
...@@ -1465,12 +1639,15 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): ...@@ -1465,12 +1639,15 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None, params: dict = None,
dropout_rng: jax.random.PRNGKey = None, dropout_rng: jax.random.PRNGKey = None,
train: bool = False, train: bool = False,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
past_key_values: dict = None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -1496,20 +1673,61 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): ...@@ -1496,20 +1673,61 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
if dropout_rng is not None: if dropout_rng is not None:
rngs["dropout"] = dropout_rng rngs["dropout"] = dropout_rng
return self.module.apply( inputs = {"params": params or self.params}
{"params": params or self.params},
if self.config.add_cross_attention:
# If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
# passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
# changed by the FlaxBigBirdAttention module.
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"), jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"), jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"), token_type_ids=jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"), position_ids=jnp.array(position_ids, dtype="i4"),
jnp.array(head_mask, dtype="i4"), head_mask=jnp.array(head_mask, dtype="i4"),
not train, encoder_hidden_states=encoder_hidden_states,
output_attentions, encoder_attention_mask=encoder_attention_mask,
output_hidden_states, deterministic=not train,
return_dict, output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
else:
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs, rngs=rngs,
) )
return outputs
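For orientation, a hedged sketch of how the cache is threaded through this call during incremental decoding; it reuses the model and past_key_values from the init_cache sketch above (so add_cross_attention is set and the cached branch is taken) and feeds one made-up token per sequence:

# Illustrative single decoding step with the mutable "cache" collection.
import jax.numpy as jnp

next_token = jnp.array([[5], [7]], dtype="i4")        # one new token id per sequence (made up)
attention_mask = jnp.ones((2, 16), dtype="i4")        # static mask over max_length, as used in generation
position_ids = jnp.zeros((2, 1), dtype="i4")          # position of the token being fed

outputs = model(
    next_token,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    return_dict=True,
)
past_key_values = outputs.past_key_values             # updated cache, ready for the next step
next_token = outputs.logits[:, -1:].argmax(axis=-1)   # e.g. a greedy pick of the following token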
class FlaxBigBirdModule(nn.Module): class FlaxBigBirdModule(nn.Module):
config: BigBirdConfig config: BigBirdConfig
...@@ -1532,6 +1750,9 @@ class FlaxBigBirdModule(nn.Module): ...@@ -1532,6 +1750,9 @@ class FlaxBigBirdModule(nn.Module):
token_type_ids, token_type_ids,
position_ids, position_ids,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -1545,6 +1766,9 @@ class FlaxBigBirdModule(nn.Module): ...@@ -1545,6 +1766,9 @@ class FlaxBigBirdModule(nn.Module):
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
deterministic=deterministic, deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
...@@ -1559,11 +1783,12 @@ class FlaxBigBirdModule(nn.Module): ...@@ -1559,11 +1783,12 @@ class FlaxBigBirdModule(nn.Module):
return (hidden_states,) + outputs[1:] return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:] return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling( return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
pooler_output=pooled, pooler_output=pooled,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
) )
...@@ -2181,3 +2406,110 @@ append_call_sample_docstring( ...@@ -2181,3 +2406,110 @@ append_call_sample_docstring(
FlaxBigBirdForQuestionAnsweringModelOutput, FlaxBigBirdForQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC, _CONFIG_FOR_DOC,
) )
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLMModule with Bert->BigBird
class FlaxBigBirdForCausalLMModule(nn.Module):
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
token_type_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.bert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
BigBird Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
autoregressive tasks.
""",
BIG_BIRD_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->BigBird
class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
module_class = FlaxBigBirdForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
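A small numeric illustration of what this preparation yields for a single left-padded prompt; all values are made up for the example:

import jax.numpy as jnp
from jax import lax

attention_mask = jnp.array([[0, 1, 1]], dtype="i4")    # leading position is padding
max_length = 6

position_ids = attention_mask.cumsum(axis=-1) - 1      # -> [[-1, 0, 1]]; the -1 lands on a masked position
extended_attention_mask = lax.dynamic_update_slice(
    jnp.ones((1, max_length), dtype="i4"), attention_mask, (0, 0)
)                                                       # -> [[0, 1, 1, 1, 1, 1]]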
append_call_sample_docstring(
FlaxBigBirdForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
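End to end, the main consumer of the new *ForCausalLM classes is the Flax encoder-decoder wrapper, which loads them as decoders with cross-attention enabled. A hedged sketch with randomly initialized weights; the config values and toy token ids are assumptions:

# Illustrative only: pair a small BigBird encoder with a BigBird causal-LM decoder.
import jax.numpy as jnp
from transformers import BigBirdConfig, EncoderDecoderConfig, FlaxEncoderDecoderModel

encoder_config = BigBirdConfig(num_hidden_layers=2, attention_type="original_full")
decoder_config = BigBirdConfig(num_hidden_layers=2, attention_type="original_full")
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = FlaxEncoderDecoderModel(config)   # decoder is resolved through the causal-LM mapping

input_ids = jnp.ones((1, 8), dtype="i4")
decoder_input_ids = jnp.ones((1, 4), dtype="i4")
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print(outputs.logits.shape)               # (1, 4, decoder vocab size)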
...@@ -59,6 +59,7 @@ if is_tf_available(): ...@@ -59,6 +59,7 @@ if is_tf_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_electra"] = [ _import_structure["modeling_flax_electra"] = [
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM", "FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice", "FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining", "FlaxElectraForPreTraining",
...@@ -107,6 +108,7 @@ if TYPE_CHECKING: ...@@ -107,6 +108,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_electra import ( from .modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM, FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice, FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining, FlaxElectraForPreTraining,
......
...@@ -22,13 +22,15 @@ import flax.linen as nn ...@@ -22,13 +22,15 @@ import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import ( from ...modeling_flax_outputs import (
FlaxBaseModelOutput, FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput, FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput, FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput, FlaxQuestionAnsweringModelOutput,
...@@ -184,9 +186,11 @@ class FlaxElectraEmbeddings(nn.Module): ...@@ -184,9 +186,11 @@ class FlaxElectraEmbeddings(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
class FlaxElectraSelfAttention(nn.Module): class FlaxElectraSelfAttention(nn.Module):
config: ElectraConfig config: ElectraConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self): def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0: if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError( raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
...@@ -209,30 +213,113 @@ class FlaxElectraSelfAttention(nn.Module): ...@@ -209,30 +213,113 @@ class FlaxElectraSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key
# positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
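The index arithmetic above is easiest to see on a toy buffer; a self-contained illustration, independent of any model code:

# Toy illustration of the cache update: write one new key vector into a pre-allocated
# (batch=1, max_length=4, heads=1, head_dim=2) buffer at position cur_index.
import jax.numpy as jnp
from jax import lax

cached_key = jnp.zeros((1, 4, 1, 2))
new_key = jnp.ones((1, 1, 1, 2))           # projected key for the single fed token
cur_index = 2                               # two positions already cached

cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
# rows 0-1 keep their old values, row 2 now holds new_key, row 3 stays zero

pad_mask = jnp.arange(4) < cur_index + 1    # -> [True, True, True, False]
# the query may attend to everything cached so far, but not to the untouched tail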
def __call__( def __call__(
self, self,
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
key_value_states: Optional[jnp.array] = None,
init_cache: bool = False,
deterministic=True, deterministic=True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
head_dim = self.config.hidden_size // self.config.num_attention_heads # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
query_states = self.query(hidden_states).reshape( is_cross_attention = key_value_states is not None
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) batch_size = hidden_states.shape[0]
)
value_states = self.value(hidden_states).reshape( # get query proj
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) query_states = self.query(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.key(key_value_states)
value_states = self.value(key_value_states)
else:
# self_attention
key_states = self.key(hidden_states)
value_states = self.value(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle the cache and prepare the causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
) )
key_states = self.key(hidden_states).reshape( else:
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
) )
# Convert the boolean attention mask to an attention bias. # Convert the boolean attention mask to an attention bias.
if attention_mask is not None: if attention_mask is not None:
# attention mask in the form of attention bias # attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select( attention_bias = lax.select(
attention_mask > 0, attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
...@@ -292,10 +379,11 @@ class FlaxElectraSelfOutput(nn.Module): ...@@ -292,10 +379,11 @@ class FlaxElectraSelfOutput(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
class FlaxElectraAttention(nn.Module): class FlaxElectraAttention(nn.Module):
config: ElectraConfig config: ElectraConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
def setup(self): def setup(self):
self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype) self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype) self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
def __call__( def __call__(
...@@ -303,6 +391,8 @@ class FlaxElectraAttention(nn.Module): ...@@ -303,6 +391,8 @@ class FlaxElectraAttention(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True, deterministic=True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
...@@ -313,6 +403,8 @@ class FlaxElectraAttention(nn.Module): ...@@ -313,6 +403,8 @@ class FlaxElectraAttention(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -373,27 +465,46 @@ class FlaxElectraLayer(nn.Module): ...@@ -373,27 +465,46 @@ class FlaxElectraLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self): def setup(self):
self.attention = FlaxElectraAttention(self.config, dtype=self.dtype) self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype) self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
self.output = FlaxElectraOutput(self.config, dtype=self.dtype) self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)
def __call__( def __call__(
self, self,
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
# Self Attention
attention_outputs = self.attention( attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
attention_output = attention_outputs[0] attention_output = attention_outputs[0]
# Cross-Attention Block
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output) hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
...@@ -401,6 +512,8 @@ class FlaxElectraLayer(nn.Module): ...@@ -401,6 +512,8 @@ class FlaxElectraLayer(nn.Module):
if output_attentions: if output_attentions:
outputs += (attention_outputs[1],) outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)
return outputs return outputs
...@@ -419,6 +532,9 @@ class FlaxElectraLayerCollection(nn.Module): ...@@ -419,6 +532,9 @@ class FlaxElectraLayerCollection(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -426,6 +542,7 @@ class FlaxElectraLayerCollection(nn.Module): ...@@ -426,6 +542,7 @@ class FlaxElectraLayerCollection(nn.Module):
): ):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired # Check if head_mask has a correct number of layers specified if desired
if head_mask is not None: if head_mask is not None:
...@@ -443,6 +560,9 @@ class FlaxElectraLayerCollection(nn.Module): ...@@ -443,6 +560,9 @@ class FlaxElectraLayerCollection(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None, layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -452,6 +572,9 @@ class FlaxElectraLayerCollection(nn.Module): ...@@ -452,6 +572,9 @@ class FlaxElectraLayerCollection(nn.Module):
if output_attentions: if output_attentions:
all_attentions += (layer_outputs[1],) all_attentions += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
...@@ -460,8 +583,11 @@ class FlaxElectraLayerCollection(nn.Module): ...@@ -460,8 +583,11 @@ class FlaxElectraLayerCollection(nn.Module):
if not return_dict: if not return_dict:
return tuple(v for v in outputs if v is not None) return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
) )
...@@ -478,6 +604,9 @@ class FlaxElectraEncoder(nn.Module): ...@@ -478,6 +604,9 @@ class FlaxElectraEncoder(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -487,6 +616,9 @@ class FlaxElectraEncoder(nn.Module): ...@@ -487,6 +616,9 @@ class FlaxElectraEncoder(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -548,6 +680,7 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): ...@@ -548,6 +680,7 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
module = self.module_class(config=config, dtype=dtype, **kwargs) module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors # init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4") input_ids = jnp.zeros(input_shape, dtype="i4")
...@@ -559,9 +692,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): ...@@ -559,9 +692,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng) params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng} rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init( if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"] )
random_params = module_init_outputs["params"]
if params is not None: if params is not None:
random_params = flatten_dict(unfreeze(random_params)) random_params = flatten_dict(unfreeze(random_params))
...@@ -573,6 +723,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): ...@@ -573,6 +723,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
else: else:
return random_params return random_params
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__( def __call__(
self, self,
...@@ -581,12 +751,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): ...@@ -581,12 +751,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None, params: dict = None,
dropout_rng: PRNGKey = None, dropout_rng: jax.random.PRNGKey = None,
train: bool = False, train: bool = False,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
past_key_values: dict = None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
...@@ -613,20 +786,61 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): ...@@ -613,20 +786,61 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
if dropout_rng is not None: if dropout_rng is not None:
rngs["dropout"] = dropout_rng rngs["dropout"] = dropout_rng
return self.module.apply( inputs = {"params": params or self.params}
{"params": params or self.params},
if self.config.add_cross_attention:
# If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
# passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
# changed by the FlaxElectraAttention module.
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
else:
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"), jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"), jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"), token_type_ids=jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"), position_ids=jnp.array(position_ids, dtype="i4"),
jnp.array(head_mask, dtype="i4"), head_mask=jnp.array(head_mask, dtype="i4"),
not train, deterministic=not train,
output_attentions, output_attentions=output_attentions,
output_hidden_states, output_hidden_states=output_hidden_states,
return_dict, return_dict=return_dict,
rngs=rngs, rngs=rngs,
) )
return outputs
class FlaxElectraModule(nn.Module): class FlaxElectraModule(nn.Module):
config: ElectraConfig config: ElectraConfig
...@@ -645,6 +859,9 @@ class FlaxElectraModule(nn.Module): ...@@ -645,6 +859,9 @@ class FlaxElectraModule(nn.Module):
token_type_ids, token_type_ids,
position_ids, position_ids,
head_mask: Optional[np.ndarray] = None, head_mask: Optional[np.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -661,6 +878,9 @@ class FlaxElectraModule(nn.Module): ...@@ -661,6 +878,9 @@ class FlaxElectraModule(nn.Module):
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
deterministic=deterministic, deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
...@@ -1232,3 +1452,111 @@ append_call_sample_docstring( ...@@ -1232,3 +1452,111 @@ append_call_sample_docstring(
FlaxSequenceClassifierOutput, FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC, _CONFIG_FOR_DOC,
) )
class FlaxElectraForCausalLMModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
if self.config.tie_word_embeddings:
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
else:
self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask: Optional[jnp.ndarray] = None,
token_type_ids: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
prediction_scores = self.generator_predictions(hidden_states)
if self.config.tie_word_embeddings:
shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
else:
prediction_scores = self.generator_lm_head(prediction_scores)
if not return_dict:
return (prediction_scores,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
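Since the generator head reuses the input embedding matrix when tie_word_embeddings is set, the tied projection above amounts to a matmul against the transposed embedding table; a minimal numerical sketch of that weight-tying idea, with illustrative shapes:

# Weight tying in a nutshell: reuse the (vocab_size, embedding_size) embedding matrix,
# transposed, as the output projection instead of a separately learned lm head.
import jax

embedding = jax.random.normal(jax.random.PRNGKey(0), (30522, 128))   # word embedding table
hidden = jax.random.normal(jax.random.PRNGKey(1), (2, 7, 128))       # generator_predictions output
logits = hidden @ embedding.T                                        # (2, 7, 30522)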
@add_start_docstrings(
"""
Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
autoregressive tasks.
""",
ELECTRA_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
FlaxElectraForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
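For completeness, the new class can also be exercised directly as a decoder by feeding pre-computed encoder states; a hedged sketch in which the config values, shapes, and token ids are assumptions:

# Standalone decoder-style forward pass with cross-attention over dummy encoder states.
import jax.numpy as jnp
from transformers import ElectraConfig, FlaxElectraForCausalLM

config = ElectraConfig(num_hidden_layers=2, is_decoder=True, add_cross_attention=True)
model = FlaxElectraForCausalLM(config)

input_ids = jnp.ones((1, 5), dtype="i4")
encoder_hidden_states = jnp.zeros((1, 5, config.hidden_size))

outputs = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=jnp.ones((1, 5), dtype="i4"),
    output_attentions=True,
)
print(outputs.logits.shape)            # (1, 5, config.vocab_size)
print(len(outputs.cross_attentions))   # one cross-attention map per layer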
...@@ -58,6 +58,7 @@ if is_tf_available(): ...@@ -58,6 +58,7 @@ if is_tf_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_roberta"] = [ _import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM", "FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice", "FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering", "FlaxRobertaForQuestionAnswering",
...@@ -103,7 +104,8 @@ if TYPE_CHECKING: ...@@ -103,7 +104,8 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_tf_roberta import ( from .modeling_flax_roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM, FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice, FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering, FlaxRobertaForQuestionAnswering,
......
...@@ -20,14 +20,16 @@ import flax.linen as nn ...@@ -20,14 +20,16 @@ import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import ( from ...modeling_flax_outputs import (
FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxBaseModelOutputWithPooling, FlaxBaseModelOutputWithPooling,
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput, FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput, FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput, FlaxQuestionAnsweringModelOutput,
...@@ -174,9 +176,11 @@ class FlaxRobertaEmbeddings(nn.Module): ...@@ -174,9 +176,11 @@ class FlaxRobertaEmbeddings(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
class FlaxRobertaSelfAttention(nn.Module): class FlaxRobertaSelfAttention(nn.Module):
config: RobertaConfig config: RobertaConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self): def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0: if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError( raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
...@@ -199,30 +203,113 @@ class FlaxRobertaSelfAttention(nn.Module): ...@@ -199,30 +203,113 @@ class FlaxRobertaSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key
# positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__( def __call__(
self, self,
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
key_value_states: Optional[jnp.array] = None,
init_cache: bool = False,
deterministic=True, deterministic=True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
head_dim = self.config.hidden_size // self.config.num_attention_heads # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
query_states = self.query(hidden_states).reshape( is_cross_attention = key_value_states is not None
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) batch_size = hidden_states.shape[0]
)
value_states = self.value(hidden_states).reshape( # get query proj
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) query_states = self.query(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.key(key_value_states)
value_states = self.value(key_value_states)
else:
# self_attention
key_states = self.key(hidden_states)
value_states = self.value(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle the cache and prepare the causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
) )
key_states = self.key(hidden_states).reshape( else:
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
) )
# Convert the boolean attention mask to an attention bias. # Convert the boolean attention mask to an attention bias.
if attention_mask is not None: if attention_mask is not None:
# attention mask in the form of attention bias # attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select( attention_bias = lax.select(
attention_mask > 0, attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
...@@ -282,10 +369,11 @@ class FlaxRobertaSelfOutput(nn.Module): ...@@ -282,10 +369,11 @@ class FlaxRobertaSelfOutput(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module): class FlaxRobertaAttention(nn.Module):
config: RobertaConfig config: RobertaConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
def setup(self): def setup(self):
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype) self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__( def __call__(
...@@ -293,6 +381,8 @@ class FlaxRobertaAttention(nn.Module): ...@@ -293,6 +381,8 @@ class FlaxRobertaAttention(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True, deterministic=True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
...@@ -303,6 +393,8 @@ class FlaxRobertaAttention(nn.Module): ...@@ -303,6 +393,8 @@ class FlaxRobertaAttention(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -363,27 +455,46 @@ class FlaxRobertaLayer(nn.Module): ...@@ -363,27 +455,46 @@ class FlaxRobertaLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self): def setup(self):
self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype) self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
def __call__( def __call__(
self, self,
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
# Self Attention
attention_outputs = self.attention( attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
attention_output = attention_outputs[0] attention_output = attention_outputs[0]
# Cross-Attention Block
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output) hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
...@@ -391,6 +502,8 @@ class FlaxRobertaLayer(nn.Module): ...@@ -391,6 +502,8 @@ class FlaxRobertaLayer(nn.Module):
if output_attentions: if output_attentions:
outputs += (attention_outputs[1],) outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)
return outputs return outputs
...@@ -409,6 +522,9 @@ class FlaxRobertaLayerCollection(nn.Module): ...@@ -409,6 +522,9 @@ class FlaxRobertaLayerCollection(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -416,6 +532,7 @@ class FlaxRobertaLayerCollection(nn.Module): ...@@ -416,6 +532,7 @@ class FlaxRobertaLayerCollection(nn.Module):
): ):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired # Check if head_mask has a correct number of layers specified if desired
if head_mask is not None: if head_mask is not None:
...@@ -433,6 +550,9 @@ class FlaxRobertaLayerCollection(nn.Module): ...@@ -433,6 +550,9 @@ class FlaxRobertaLayerCollection(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None, layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -442,6 +562,9 @@ class FlaxRobertaLayerCollection(nn.Module): ...@@ -442,6 +562,9 @@ class FlaxRobertaLayerCollection(nn.Module):
if output_attentions: if output_attentions:
all_attentions += (layer_outputs[1],) all_attentions += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
...@@ -450,8 +573,11 @@ class FlaxRobertaLayerCollection(nn.Module): ...@@ -450,8 +573,11 @@ class FlaxRobertaLayerCollection(nn.Module):
if not return_dict: if not return_dict:
return tuple(v for v in outputs if v is not None) return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
) )
...@@ -468,6 +594,9 @@ class FlaxRobertaEncoder(nn.Module): ...@@ -468,6 +594,9 @@ class FlaxRobertaEncoder(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -477,6 +606,9 @@ class FlaxRobertaEncoder(nn.Module): ...@@ -477,6 +606,9 @@ class FlaxRobertaEncoder(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -603,9 +735,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): ...@@ -603,9 +735,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng) params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng} rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init( if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"] )
random_params = module_init_outputs["params"]
if params is not None: if params is not None:
random_params = flatten_dict(unfreeze(random_params)) random_params = flatten_dict(unfreeze(random_params))
...@@ -617,6 +766,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): ...@@ -617,6 +766,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
else: else:
return random_params return random_params
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__( def __call__(
self, self,
...@@ -625,12 +794,15 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): ...@@ -625,12 +794,15 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None, params: dict = None,
dropout_rng: PRNGKey = None, dropout_rng: jax.random.PRNGKey = None,
train: bool = False, train: bool = False,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
past_key_values: dict = None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -656,20 +828,61 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): ...@@ -656,20 +828,61 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
if dropout_rng is not None: if dropout_rng is not None:
rngs["dropout"] = dropout_rng rngs["dropout"] = dropout_rng
return self.module.apply( inputs = {"params": params or self.params}
{"params": params or self.params},
if self.config.add_cross_attention:
# if past_key_values are passed, the cache is already initialized; a private flag init_cache has to be passed
# down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
# changed by the FlaxRobertaAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
else:
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"), jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"), jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"), token_type_ids=jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"), position_ids=jnp.array(position_ids, dtype="i4"),
jnp.array(head_mask, dtype="i4"), head_mask=jnp.array(head_mask, dtype="i4"),
not train, deterministic=not train,
output_attentions, output_attentions=output_attentions,
output_hidden_states, output_hidden_states=output_hidden_states,
return_dict, return_dict=return_dict,
rngs=rngs, rngs=rngs,
) )
return outputs
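For orientation, the snippet below is a minimal sketch (not part of this diff) of how the cache plumbing added to `__call__` fits together: `init_cache` pre-allocates the per-layer buffers, and passing them back as `past_key_values` makes the model return the updated cache under `outputs.past_key_values`. The tiny random-weight config is an illustrative assumption.

import jax.numpy as jnp
from transformers import RobertaConfig, FlaxRobertaForCausalLM

# toy decoder config (assumed values); the cache branch above requires add_cross_attention=True
config = RobertaConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=37, is_decoder=True, add_cross_attention=True,
)
model = FlaxRobertaForCausalLM(config)

encoder_hidden_states = jnp.zeros((1, 5, config.hidden_size))
cache = model.init_cache(batch_size=1, max_length=4)

# decode a single token; the attention mask spans the full pre-allocated cache length
step = model(
    jnp.array([[0]], dtype="i4"),
    attention_mask=jnp.ones((1, 4), dtype="i4"),
    position_ids=jnp.array([[0]], dtype="i4"),
    encoder_hidden_states=encoder_hidden_states,
    past_key_values=cache,
)
cache = step.past_key_values  # feed back on the next step with advanced position_ids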
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module): class FlaxRobertaModule(nn.Module):
...@@ -686,9 +899,12 @@ class FlaxRobertaModule(nn.Module): ...@@ -686,9 +899,12 @@ class FlaxRobertaModule(nn.Module):
self, self,
input_ids, input_ids,
attention_mask, attention_mask,
token_type_ids: Optional[np.ndarray] = None, token_type_ids: Optional[jnp.ndarray] = None,
position_ids: Optional[np.ndarray] = None, position_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[np.ndarray] = None, head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -710,6 +926,9 @@ class FlaxRobertaModule(nn.Module): ...@@ -710,6 +926,9 @@ class FlaxRobertaModule(nn.Module):
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
deterministic=deterministic, deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
...@@ -723,11 +942,12 @@ class FlaxRobertaModule(nn.Module): ...@@ -723,11 +942,12 @@ class FlaxRobertaModule(nn.Module):
return (hidden_states,) + outputs[1:] return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:] return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling( return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
pooler_output=pooled, pooler_output=pooled,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
) )
...@@ -1101,3 +1321,108 @@ append_call_sample_docstring( ...@@ -1101,3 +1321,108 @@ append_call_sample_docstring(
FlaxQuestionAnsweringModelOutput, FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC, _CONFIG_FOR_DOC,
) )
class FlaxRobertaForCausalLMModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
token_type_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.roberta(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output), e.g. for
autoregressive tasks.
""",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0s in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
FlaxRobertaForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
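The main consumer of the new class is FlaxEncoderDecoderModel, which loads its decoder through FlaxRobertaForCausalLM. A hedged end-to-end sketch follows; the checkpoint name and decoder start token are illustrative assumptions, and a warm-started, untuned model will not produce meaningful text.

from transformers import RobertaTokenizerFast, FlaxEncoderDecoderModel

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
# the decoder half is instantiated as FlaxRobertaForCausalLM with cross-attention enabled
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="np")
generated = model.generate(
    inputs.input_ids,
    decoder_start_token_id=tokenizer.cls_token_id,  # assumed choice of start token
    pad_token_id=tokenizer.pad_token_id,
    max_length=16,
)
print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))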
...@@ -326,6 +326,13 @@ class FlaxBeitPreTrainedModel(metaclass=DummyObject): ...@@ -326,6 +326,13 @@ class FlaxBeitPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxBertForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBertForMaskedLM(metaclass=DummyObject): class FlaxBertForMaskedLM(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
...@@ -389,6 +396,13 @@ class FlaxBertPreTrainedModel(metaclass=DummyObject): ...@@ -389,6 +396,13 @@ class FlaxBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxBigBirdForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBigBirdForMaskedLM(metaclass=DummyObject): class FlaxBigBirdForMaskedLM(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
...@@ -578,6 +592,13 @@ class FlaxDistilBertPreTrainedModel(metaclass=DummyObject): ...@@ -578,6 +592,13 @@ class FlaxDistilBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxElectraForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxElectraForMaskedLM(metaclass=DummyObject): class FlaxElectraForMaskedLM(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
...@@ -795,6 +816,13 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject): ...@@ -795,6 +816,13 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxRobertaForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForMaskedLM(metaclass=DummyObject): class FlaxRobertaForMaskedLM(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
...@@ -24,15 +24,17 @@ import flax.linen as nn ...@@ -24,15 +24,17 @@ import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import ( from ...modeling_flax_outputs import (
FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxBaseModelOutputWithPooling, FlaxBaseModelOutputWithPoolingAndCrossAttentions,
FlaxCausalLMOutput, FlaxCausalLMOutput,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput, FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput, FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput, FlaxQuestionAnsweringModelOutput,
...@@ -170,9 +172,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module): ...@@ -170,9 +172,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module): class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config config: {{cookiecutter.camelcase_modelname}}Config
causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self): def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0: if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError( raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
...@@ -195,30 +199,113 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module): ...@@ -195,30 +199,113 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
) )
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key
# positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
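# Worked example with illustrative sizes: for max_length=4, cur_index=2 and one new token,
# `indices` is (0, 2, 0, 0), so the new key/value slices are written into row 2 of the
# pre-allocated (batch, 4, num_heads, head_dim) buffers, cache_index advances to 3, and
# pad_mask lets the query attend to positions 0-2 while the still-empty row 3 stays masked.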
return key, value, attention_mask
def __call__( def __call__(
self, self,
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
key_value_states: Optional[jnp.array] = None,
init_cache: bool = False,
deterministic=True, deterministic=True,
output_attentions: bool = False output_attentions: bool = False,
): ):
head_dim = self.config.hidden_size // self.config.num_attention_heads # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
query_states = self.query(hidden_states).reshape( # get query proj
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) query_states = self.query(hidden_states)
) # get key, value proj
value_states = self.value(hidden_states).reshape( if is_cross_attention:
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) # cross_attentions
key_states = self.key(key_value_states)
value_states = self.value(key_value_states)
else:
# self_attention
key_states = self.key(hidden_states)
value_states = self.value(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
) )
key_states = self.key(hidden_states).reshape( else:
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
) )
# Convert the boolean attention mask to an attention bias. # Convert the boolean attention mask to an attention bias.
if attention_mask is not None: if attention_mask is not None:
# attention mask in the form of attention bias # attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select( attention_bias = lax.select(
attention_mask > 0, attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
...@@ -278,6 +365,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module): ...@@ -278,6 +365,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config config: {{cookiecutter.camelcase_modelname}}Config
causal: bool = False
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
def setup(self): def setup(self):
...@@ -289,6 +377,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): ...@@ -289,6 +377,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True, deterministic=True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
...@@ -299,6 +389,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): ...@@ -299,6 +389,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -362,24 +454,43 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module): ...@@ -362,24 +454,43 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype) self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype)
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype) self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype) self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, causal=False, dtype=self.dtype)
def __call__( def __call__(
self, self,
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
): ):
# Self Attention
attention_outputs = self.attention( attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
attention_output = attention_outputs[0] attention_output = attention_outputs[0]
# Cross-Attention Block
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output) hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
...@@ -387,6 +498,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module): ...@@ -387,6 +498,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
if output_attentions: if output_attentions:
outputs += (attention_outputs[1],) outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)
return outputs return outputs
...@@ -405,6 +518,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): ...@@ -405,6 +518,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -412,6 +528,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): ...@@ -412,6 +528,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
): ):
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired # Check if head_mask has a correct number of layers specified if desired
if head_mask is not None: if head_mask is not None:
...@@ -429,6 +546,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): ...@@ -429,6 +546,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None, layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -438,6 +558,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): ...@@ -438,6 +558,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
if output_attentions: if output_attentions:
all_attentions += (layer_outputs[1],) all_attentions += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
...@@ -446,8 +569,11 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): ...@@ -446,8 +569,11 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
if not return_dict: if not return_dict:
return tuple(v for v in outputs if v is not None) return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
) )
...@@ -464,6 +590,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): ...@@ -464,6 +590,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask, head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
...@@ -473,6 +602,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): ...@@ -473,6 +602,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic, deterministic=deterministic,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -598,6 +730,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode ...@@ -598,6 +730,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
module = self.module_class(config=config, dtype=dtype, **kwargs) module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors # init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4") input_ids = jnp.zeros(input_shape, dtype="i4")
...@@ -609,9 +742,26 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode ...@@ -609,9 +742,26 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
params_rng, dropout_rng = jax.random.split(rng) params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng} rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init( if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"] )
random_params = module_init_outputs["params"]
if params is not None: if params is not None:
random_params = flatten_dict(unfreeze(random_params)) random_params = flatten_dict(unfreeze(random_params))
...@@ -623,7 +773,29 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode ...@@ -623,7 +773,29 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
else: else:
return random_params return random_params
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_cache with Bert->{{cookiecutter.camelcase_modelname}}
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->{{cookiecutter.camelcase_modelname}}
def __call__( def __call__(
self, self,
input_ids, input_ids,
...@@ -631,12 +803,15 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode ...@@ -631,12 +803,15 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None, params: dict = None,
dropout_rng: jax.random.PRNGKey = None, dropout_rng: jax.random.PRNGKey = None,
train: bool = False, train: bool = False,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
past_key_values: dict = None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -662,20 +837,61 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode ...@@ -662,20 +837,61 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
if dropout_rng is not None: if dropout_rng is not None:
rngs["dropout"] = dropout_rng rngs["dropout"] = dropout_rng
return self.module.apply( inputs = {"params": params or self.params}
{"params": params or self.params},
if self.config.add_cross_attention:
# if past_key_values are passed, the cache is already initialized; a private flag init_cache has to be passed
# down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
# changed by the FlaxBertAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"), jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"), jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"), token_type_ids=jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"), position_ids=jnp.array(position_ids, dtype="i4"),
jnp.array(head_mask, dtype="i4"), head_mask=jnp.array(head_mask, dtype="i4"),
not train, encoder_hidden_states=encoder_hidden_states,
output_attentions, encoder_attention_mask=encoder_attention_mask,
output_hidden_states, deterministic=not train,
return_dict, output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
else:
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
head_mask=jnp.array(head_mask, dtype="i4"),
deterministic=not train,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs, rngs=rngs,
) )
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config config: {{cookiecutter.camelcase_modelname}}Config
...@@ -691,14 +907,25 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): ...@@ -691,14 +907,25 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
self, self,
input_ids, input_ids,
attention_mask, attention_mask,
token_type_ids, token_type_ids: Optional[jnp.ndarray] = None,
position_ids, position_ids: Optional[jnp.ndarray] = None,
head_mask, head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True, deterministic: bool = True,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
return_dict: bool = True, return_dict: bool = True,
): ):
# make sure `token_type_ids` is correctly initialized when not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# make sure `position_ids` is correctly initialized when not passed
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
hidden_states = self.embeddings( hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
) )
...@@ -707,6 +934,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): ...@@ -707,6 +934,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
attention_mask, attention_mask,
head_mask=head_mask, head_mask=head_mask,
deterministic=deterministic, deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
...@@ -720,11 +950,12 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): ...@@ -720,11 +950,12 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
return (hidden_states,) + outputs[1:] return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:] return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling( return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
pooler_output=pooled, pooler_output=pooled,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
) )
add_start_docstrings( add_start_docstrings(
...@@ -1137,6 +1368,112 @@ append_call_sample_docstring( ...@@ -1137,6 +1368,112 @@ append_call_sample_docstring(
FlaxQuestionAnsweringModelOutput, FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC, _CONFIG_FOR_DOC,
) )
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
token_type_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Model
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
{{cookiecutter.camelcase_modelname}} Model with a language modeling head on top (a linear layer on top of the hidden-states output), e.g. for
autoregressive tasks.
""",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
)
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0s in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
{# encoder_decoder #} {# encoder_decoder #}
{% else %} {% else %}
import math import math
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from transformers import BertConfig, is_flax_available from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow from transformers.testing_utils import require_flax, slow
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available(): if is_flax_available():
...@@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase): ...@@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
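For context, a standalone sketch of what these decoder inputs ultimately exercise: a causal LM attending over encoder states through its cross-attention. The toy config values are assumptions, not taken from this tester.

import numpy as np
from transformers import BertConfig, FlaxBertForCausalLM

config = BertConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=37, is_decoder=True, add_cross_attention=True,
)
model = FlaxBertForCausalLM(config)

input_ids = np.ones((2, 7), dtype="i4")
encoder_hidden_states = np.zeros((2, 7, config.hidden_size), dtype="f4")
logits = model(input_ids, encoder_hidden_states=encoder_hidden_states).logits
assert logits.shape == (2, 7, config.vocab_size)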
@require_flax @require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
......
...@@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random ...@@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
if is_flax_available(): if is_flax_available():
import jax import jax
from transformers.models.big_bird.modeling_flax_big_bird import ( from transformers.models.big_bird.modeling_flax_big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM, FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice, FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining, FlaxBigBirdForPreTraining,
...@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
FlaxBigBirdForCausalLM,
FlaxBigBirdModel, FlaxBigBirdModel,
FlaxBigBirdForPreTraining, FlaxBigBirdForPreTraining,
FlaxBigBirdForMaskedLM, FlaxBigBirdForMaskedLM,
......
...@@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random ...@@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
if is_flax_available(): if is_flax_available():
from transformers.models.electra.modeling_flax_electra import ( from transformers.models.electra.modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM, FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice, FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining, FlaxElectraForPreTraining,
...@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
FlaxElectraModel, FlaxElectraModel,
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM, FlaxElectraForMaskedLM,
FlaxElectraForPreTraining, FlaxElectraForPreTraining,
FlaxElectraForTokenClassification, FlaxElectraForTokenClassification,
......