Unverified commit af1a10bf authored by Jayendra, committed by GitHub

[Flax] Return Attention from BERT, ELECTRA, RoBERTa and GPT2 (#11918)



* Added logic to return attention from flax-bert model and added test cases to check that

* Added new line at the end of file to test_modeling_flax_common.py

* fixing code style

* Fixing Roberta and Electra models too by copying from BERT

* Added temporary hack to not run test_attention_outputs for FlaxGPT2

* Returning attention weights from GPT2 and changed the tests accordingly.

* last fixes

* bump flax dependency
Co-authored-by: jayendra <jayendra@infocusp.in>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e1205e47
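At a glance, the diff below applies the same pattern to all four models: `flax.linen.dot_product_attention`, which only returns the attention output, is replaced by `flax.linen.attention.dot_product_attention_weights`, and the weighted sum over the values is computed explicitly so the weights can also be surfaced when `output_attentions=True`. A minimal, self-contained sketch of that pattern (the function and variable names here are illustrative, not the exact model code):

```python
import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights  # requires flax>=0.3.4


def attention_with_weights(query, key, value, bias=None, output_attentions=False):
    # query/key/value: [batch, seq_len, num_heads, head_dim]
    attn_weights = dot_product_attention_weights(query, key, bias=bias)
    # apply the weights to the values: [batch, seq_len, num_heads, head_dim]
    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    # merge the head axes back into a single hidden dimension
    attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
    return (attn_output, attn_weights) if output_attentions else (attn_output,)
```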
@@ -97,7 +97,7 @@ _deps = [
    "fastapi",
    "filelock",
    "flake8>=3.8.3",
-   "flax>=0.3.2",
+   "flax>=0.3.4",
    "fugashi>=1.0",
    "huggingface-hub==0.0.8",
    "importlib_metadata",
......
@@ -14,7 +14,7 @@ deps = {
    "fastapi": "fastapi",
    "filelock": "filelock",
    "flake8": "flake8>=3.8.3",
-   "flax": "flax>=0.3.2",
+   "flax": "flax>=0.3.4",
    "fugashi": "fugashi>=1.0",
    "huggingface-hub": "huggingface-hub==0.0.8",
    "importlib_metadata": "importlib_metadata",
......
@@ -23,7 +23,7 @@ import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
-from flax.linen import dot_product_attention
+from flax.linen.attention import dot_product_attention_weights
from jax import lax

from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -241,10 +241,9 @@ class FlaxBertSelfAttention(nn.Module):
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

-       attn_output = dot_product_attention(
+       attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
-           value_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
@@ -254,11 +253,10 @@
            precision=None,
        )

-       outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
-       # TODO: at the moment it's not possible to retrieve attn_weights from
-       # dot_product_attention, but should be in the future -> add functionality then
+       attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+       attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+       outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)

        return outputs
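For readers tracing the einsum above: `dot_product_attention_weights` returns weights of shape `[batch, num_heads, q_len, kv_len]`, while `value_states` is `[batch, kv_len, num_heads, head_dim]`, so the contraction yields `[batch, q_len, num_heads, head_dim]` before the heads are flattened. A quick, self-contained shape check (the sizes are arbitrary):

```python
import jax.numpy as jnp

batch, num_heads, q_len, kv_len, head_dim = 2, 12, 7, 7, 64
attn_weights = jnp.ones((batch, num_heads, q_len, kv_len)) / kv_len
value_states = jnp.ones((batch, kv_len, num_heads, head_dim))

attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
assert attn_output.shape == (batch, q_len, num_heads, head_dim)

# the reshape used in the models flattens the two head axes into one hidden dimension
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
assert attn_output.shape == (batch, q_len, num_heads * head_dim)
```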
@@ -303,7 +301,7 @@ class FlaxBertAttention(nn.Module):
        outputs = (hidden_states,)

        if output_attentions:
-           outputs += attn_outputs[1]
+           outputs += (attn_outputs[1],)

        return outputs
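A small illustration of why the extra parentheses matter in the hunk above: a tuple can only be concatenated with another tuple, so the attention-weight array is wrapped in a one-element tuple before being appended (the array shapes here are made up):

```python
import jax.numpy as jnp

hidden_states = jnp.ones((2, 7, 32))
attn_weights = jnp.ones((2, 4, 7, 7))

outputs = (hidden_states,)
outputs += (attn_weights,)  # appends the whole array as a single tuple element
assert len(outputs) == 2

# `outputs += attn_weights` would not append the array to the tuple;
# concatenating a tuple with a bare array is not a tuple operation at all.
```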
@@ -396,7 +394,9 @@ class FlaxBertLayerCollection(nn.Module):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

-           layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
+           layer_outputs = layer(
+               hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+           )
            hidden_states = layer_outputs[0]
@@ -582,11 +582,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

-       if output_attentions:
-           raise NotImplementedError(
-               "Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
-           )

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
......
@@ -23,7 +23,7 @@ import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
-from flax.linen import dot_product_attention
+from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
@@ -238,10 +238,9 @@ class FlaxElectraSelfAttention(nn.Module):
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

-       attn_output = dot_product_attention(
+       attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
-           value_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
@@ -251,11 +250,10 @@
            precision=None,
        )

-       outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
-       # TODO: at the moment it's not possible to retrieve attn_weights from
-       # dot_product_attention, but should be in the future -> add functionality then
+       attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+       attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+       outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)

        return outputs
@@ -302,7 +300,7 @@ class FlaxElectraAttention(nn.Module):
        outputs = (hidden_states,)

        if output_attentions:
-           outputs += attn_outputs[1]
+           outputs += (attn_outputs[1],)

        return outputs
@@ -399,7 +397,9 @@ class FlaxElectraLayerCollection(nn.Module):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

-           layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
+           layer_outputs = layer(
+               hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+           )
            hidden_states = layer_outputs[0]
@@ -534,11 +534,6 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

-       if output_attentions:
-           raise NotImplementedError(
-               "Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
-           )

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.ones_like(input_ids)
......
@@ -19,7 +19,8 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
-from flax.linen import combine_masks, dot_product_attention, make_causal_mask
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
from jax import lax

from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -215,10 +216,9 @@
        )

        # usual dot product attention
-       attn_output = dot_product_attention(
+       attn_weights = dot_product_attention_weights(
            query,
            key,
-           value,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
@@ -227,14 +227,13 @@
            precision=None,
        )

+       attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

-       # TODO: at the moment it's not possible to retrieve attn_weights from
-       # dot_product_attention, but should be in the future -> add functionality then
-       return (attn_output,)
+       outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+       return outputs
class FlaxGPT2MLP(nn.Module):
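Since GPT2 is causal, its attention bias has to mask out future positions before it reaches `dot_product_attention_weights`. A hedged sketch of that idea, using the `make_causal_mask` helper imported above (the shapes and the `jnp.where` construction are illustrative, not the exact GPT2 code, which also handles padding masks and the generation cache):

```python
import jax.numpy as jnp
from flax.linen import make_causal_mask
from flax.linen.attention import dot_product_attention_weights

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8

# boolean causal mask of shape [batch, 1, seq_len, seq_len]
causal_mask = make_causal_mask(jnp.ones((batch, seq_len)), dtype="bool")

# turn the mask into an additive bias: 0 where attention is allowed,
# a very large negative number where it is not
attention_bias = jnp.where(causal_mask, 0.0, jnp.finfo(jnp.float32).min)

query = jnp.ones((batch, seq_len, num_heads, head_dim))
key = jnp.ones((batch, seq_len, num_heads, head_dim))

attn_weights = dot_product_attention_weights(query, key, bias=attention_bias)
# attn_weights: [batch, num_heads, seq_len, seq_len]; entries above the diagonal are ~0
```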
@@ -447,7 +446,13 @@ class FlaxGPT2BlockCollection(nn.Module):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

-           layer_outputs = block(hidden_states, attention_mask, deterministic=deterministic, init_cache=init_cache)
+           layer_outputs = block(
+               hidden_states,
+               attention_mask,
+               deterministic=deterministic,
+               init_cache=init_cache,
+               output_attentions=output_attentions,
+           )
            hidden_states = layer_outputs[0]

            if output_attentions:
......
@@ -20,7 +20,7 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
-from flax.linen import dot_product_attention
+from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
@@ -227,10 +227,9 @@ class FlaxRobertaSelfAttention(nn.Module):
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

-       attn_output = dot_product_attention(
+       attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
-           value_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
@@ -240,11 +239,10 @@
            precision=None,
        )

-       outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
-       # TODO: at the moment it's not possible to retrieve attn_weights from
-       # dot_product_attention, but should be in the future -> add functionality then
+       attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+       attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+       outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)

        return outputs
@@ -291,7 +289,7 @@ class FlaxRobertaAttention(nn.Module):
        outputs = (hidden_states,)

        if output_attentions:
-           outputs += attn_outputs[1]
+           outputs += (attn_outputs[1],)

        return outputs
@@ -388,7 +386,9 @@ class FlaxRobertaLayerCollection(nn.Module):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

-           layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
+           layer_outputs = layer(
+               hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+           )
            hidden_states = layer_outputs[0]
@@ -570,11 +570,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

-       if output_attentions:
-           raise NotImplementedError(
-               "Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
-           )

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
......
@@ -79,8 +79,9 @@ class FlaxModelTesterMixin:
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
-               for k, v in inputs_dict.items()
                if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
+               else v
+               for k, v in inputs_dict.items()
            }
        return inputs_dict
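The comprehension change above is subtle: with a trailing `if` clause, non-array entries (such as the new boolean `output_attentions` flag) would be dropped from `inputs_dict`, whereas the `... if ... else v` conditional expression keeps them unchanged. A tiny illustration with made-up values:

```python
import numpy as np

inputs = {"input_ids": np.ones((2, 5)), "output_attentions": True}

# a trailing `if` acts as a filter: the boolean flag is silently dropped
filtered = {k: v * 2 for k, v in inputs.items() if isinstance(v, np.ndarray)}
assert "output_attentions" not in filtered

# a conditional expression maps the arrays and passes everything else through
mapped = {k: v * 2 if isinstance(v, np.ndarray) else v for k, v in inputs.items()}
assert mapped["output_attentions"] is True
```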
@@ -310,3 +311,48 @@ class FlaxModelTesterMixin:
            config.output_hidden_states = True
            check_hidden_states_output(inputs_dict, config, model_class)

+   def test_attention_outputs(self):
+       config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+       config.return_dict = True
+
+       seq_length = getattr(self.model_tester, "seq_length", None)
+
+       for model_class in self.all_model_classes:
+           inputs_dict["output_attentions"] = True
+           inputs_dict["output_hidden_states"] = False
+           model = model_class(config)
+           outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+           attentions = outputs.attentions
+           self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+           # check that output_attentions also work using config
+           del inputs_dict["output_attentions"]
+           config.output_attentions = True
+           model = model_class(config)
+           outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+           attentions = outputs.attentions
+           self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+           self.assertListEqual(
+               list(attentions[0].shape[-3:]),
+               [self.model_tester.num_attention_heads, seq_length, seq_length],
+           )
+           out_len = len(outputs)
+
+           # Check attention is always last and order is fine
+           inputs_dict["output_attentions"] = True
+           inputs_dict["output_hidden_states"] = True
+           model = model_class(config)
+           outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+           added_hidden_states = 1
+           self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+           self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+           self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+           self.assertListEqual(
+               list(self_attentions[0].shape[-3:]),
+               [self.model_tester.num_attention_heads, seq_length, seq_length],
+           )
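For context, a hedged end-to-end sketch of what the new test exercises, assuming a transformers build that includes this change and `flax>=0.3.4` (the checkpoint name and input text are just examples):

```python
from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Attention weights are now returned.", return_tensors="np")
outputs = model(**inputs, output_attentions=True)

# one attention tensor per layer, each shaped [batch, num_heads, seq_len, seq_len]
print(len(outputs.attentions), outputs.attentions[0].shape)
```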