Unverified commit af1a10bf, authored by Jayendra and committed by GitHub

[Flax] Return Attention from BERT, ELECTRA, RoBERTa and GPT2 (#11918)



* Added logic to return attention from the Flax BERT model and added test cases to check it

* Added new line at the end of file to test_modeling_flax_common.py

* fixing code style

* Fixing RoBERTa and ELECTRA models too by copying from BERT

* Added temporary hack to not run test_attention_outputs for FlaxGPT2

* Returned attention weights from GPT2 and changed the tests accordingly.

* last fixes

* bump flax dependency

Co-authored-by: jayendra <jayendra@infocusp.in>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e1205e47
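In practice, this change lets callers of the Flax models request per-layer attention weights at call time. A minimal usage sketch of what the commit enables; the checkpoint name "bert-base-uncased" and the example sentence are illustrative assumptions, not part of the diff:

from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello world", return_tensors="np")

# With this commit, output_attentions=True returns the attention probabilities of every layer
outputs = model(**inputs, output_attentions=True)

# One entry per hidden layer, each of shape (batch_size, num_heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[0].shape)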
@@ -97,7 +97,7 @@ _deps = [
     "fastapi",
     "filelock",
     "flake8>=3.8.3",
-    "flax>=0.3.2",
+    "flax>=0.3.4",
     "fugashi>=1.0",
     "huggingface-hub==0.0.8",
     "importlib_metadata",
...
@@ -14,7 +14,7 @@ deps = {
     "fastapi": "fastapi",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
-    "flax": "flax>=0.3.2",
+    "flax": "flax>=0.3.4",
     "fugashi": "fugashi>=1.0",
     "huggingface-hub": "huggingface-hub==0.0.8",
     "importlib_metadata": "importlib_metadata",
...
@@ -23,7 +23,7 @@ import jax
 import jax.numpy as jnp
 import jaxlib.xla_extension as jax_xla
 from flax.core.frozen_dict import FrozenDict
-from flax.linen import dot_product_attention
+from flax.linen.attention import dot_product_attention_weights
 from jax import lax

 from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -241,10 +241,9 @@ class FlaxBertSelfAttention(nn.Module):
         if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
             dropout_rng = self.make_rng("dropout")

-        attn_output = dot_product_attention(
+        attn_weights = dot_product_attention_weights(
             query_states,
             key_states,
-            value_states,
             bias=attention_bias,
             dropout_rng=dropout_rng,
             dropout_rate=self.config.attention_probs_dropout_prob,
@@ -254,11 +253,10 @@ class FlaxBertSelfAttention(nn.Module):
             precision=None,
         )

-        outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

-        # TODO: at the moment it's not possible to retrieve attn_weights from
-        # dot_product_attention, but should be in the future -> add functionality then
-
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
         return outputs
@@ -303,7 +301,7 @@ class FlaxBertAttention(nn.Module):
         outputs = (hidden_states,)

         if output_attentions:
-            outputs += attn_outputs[1]
+            outputs += (attn_outputs[1],)

         return outputs
@@ -396,7 +394,9 @@ class FlaxBertLayerCollection(nn.Module):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
+            layer_outputs = layer(
+                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+            )

             hidden_states = layer_outputs[0]
@@ -582,11 +582,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict

-        if output_attentions:
-            raise NotImplementedError(
-                "Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
-            )
-
         # init input tensors if not passed
         if token_type_ids is None:
             token_type_ids = jnp.zeros_like(input_ids)
...
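All four models follow the same pattern: compute the softmax-normalized weights with dot_product_attention_weights, contract them against the values with an einsum, and return the weights next to the attention output when output_attentions is set. A standalone sketch of that pattern follows; the toy dimensions are purely illustrative assumptions:

import jax
import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights

# Arbitrary toy dimensions: (batch, seq_len, num_heads, head_dim)
batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_key, (batch, seq_len, num_heads, head_dim))
key = jax.random.normal(k_key, (batch, seq_len, num_heads, head_dim))
value = jax.random.normal(v_key, (batch, seq_len, num_heads, head_dim))

# Softmax-normalized attention probabilities, shape (batch, num_heads, q_len, k_len)
attn_weights = dot_product_attention_weights(query, key)

# Weighted sum over the values: back to (batch, q_len, num_heads, head_dim)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)

# Merge the heads and expose the weights, as the updated modules now do
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output, attn_weights)
print(attn_output.shape, attn_weights.shape)  # (2, 5, 32) and (2, 4, 5, 5)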
@@ -23,7 +23,7 @@ import jax
 import jax.numpy as jnp
 import jaxlib.xla_extension as jax_xla
 from flax.core.frozen_dict import FrozenDict
-from flax.linen import dot_product_attention
+from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 from jax.random import PRNGKey
@@ -238,10 +238,9 @@ class FlaxElectraSelfAttention(nn.Module):
         if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
             dropout_rng = self.make_rng("dropout")

-        attn_output = dot_product_attention(
+        attn_weights = dot_product_attention_weights(
             query_states,
             key_states,
-            value_states,
             bias=attention_bias,
             dropout_rng=dropout_rng,
             dropout_rate=self.config.attention_probs_dropout_prob,
@@ -251,11 +250,10 @@ class FlaxElectraSelfAttention(nn.Module):
             precision=None,
         )

-        outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

-        # TODO: at the moment it's not possible to retrieve attn_weights from
-        # dot_product_attention, but should be in the future -> add functionality then
-
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
         return outputs
@@ -302,7 +300,7 @@ class FlaxElectraAttention(nn.Module):
         outputs = (hidden_states,)

         if output_attentions:
-            outputs += attn_outputs[1]
+            outputs += (attn_outputs[1],)

         return outputs
@@ -399,7 +397,9 @@ class FlaxElectraLayerCollection(nn.Module):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
+            layer_outputs = layer(
+                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+            )

             hidden_states = layer_outputs[0]
@@ -534,11 +534,6 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict

-        if output_attentions:
-            raise NotImplementedError(
-                "Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
-            )
-
         # init input tensors if not passed
         if token_type_ids is None:
             token_type_ids = jnp.ones_like(input_ids)
...
@@ -19,7 +19,8 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, unfreeze
-from flax.linen import combine_masks, dot_product_attention, make_causal_mask
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
 from jax import lax

 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -215,10 +216,9 @@ class FlaxGPT2Attention(nn.Module):
         )

         # usual dot product attention
-        attn_output = dot_product_attention(
+        attn_weights = dot_product_attention_weights(
             query,
             key,
-            value,
             bias=attention_bias,
             dropout_rng=dropout_rng,
             dropout_rate=self.config.attn_pdrop,
@@ -227,14 +227,13 @@ class FlaxGPT2Attention(nn.Module):
             precision=None,
         )

+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
         attn_output = self._merge_heads(attn_output)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

-        # TODO: at the moment it's not possible to retrieve attn_weights from
-        # dot_product_attention, but should be in the future -> add functionality then
-        return (attn_output,)
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+        return outputs


 class FlaxGPT2MLP(nn.Module):
@@ -447,7 +446,13 @@ class FlaxGPT2BlockCollection(nn.Module):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            layer_outputs = block(hidden_states, attention_mask, deterministic=deterministic, init_cache=init_cache)
+            layer_outputs = block(
+                hidden_states,
+                attention_mask,
+                deterministic=deterministic,
+                init_cache=init_cache,
+                output_attentions=output_attentions,
+            )
             hidden_states = layer_outputs[0]

             if output_attentions:
...
@@ -20,7 +20,7 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
-from flax.linen import dot_product_attention
+from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 from jax.random import PRNGKey
@@ -227,10 +227,9 @@ class FlaxRobertaSelfAttention(nn.Module):
         if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
             dropout_rng = self.make_rng("dropout")

-        attn_output = dot_product_attention(
+        attn_weights = dot_product_attention_weights(
             query_states,
             key_states,
-            value_states,
             bias=attention_bias,
             dropout_rng=dropout_rng,
             dropout_rate=self.config.attention_probs_dropout_prob,
@@ -240,11 +239,10 @@ class FlaxRobertaSelfAttention(nn.Module):
             precision=None,
         )

-        outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

-        # TODO: at the moment it's not possible to retrieve attn_weights from
-        # dot_product_attention, but should be in the future -> add functionality then
-
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
         return outputs
@@ -291,7 +289,7 @@ class FlaxRobertaAttention(nn.Module):
         outputs = (hidden_states,)

         if output_attentions:
-            outputs += attn_outputs[1]
+            outputs += (attn_outputs[1],)

         return outputs
@@ -388,7 +386,9 @@ class FlaxRobertaLayerCollection(nn.Module):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
+            layer_outputs = layer(
+                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+            )

             hidden_states = layer_outputs[0]
@@ -570,11 +570,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict

-        if output_attentions:
-            raise NotImplementedError(
-                "Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
-            )
-
         # init input tensors if not passed
         if token_type_ids is None:
             token_type_ids = jnp.zeros_like(input_ids)
...
@@ -79,8 +79,9 @@ class FlaxModelTesterMixin:
         if "ForMultipleChoice" in model_class.__name__:
             inputs_dict = {
                 k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
-                for k, v in inputs_dict.items()
                 if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
+                else v
+                for k, v in inputs_dict.items()
             }
         return inputs_dict
@@ -310,3 +311,48 @@ class FlaxModelTesterMixin:
             config.output_hidden_states = True
             check_hidden_states_output(inputs_dict, config, model_class)

+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        seq_length = getattr(self.model_tester, "seq_length", None)
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_length, seq_length],
+            )
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            added_hidden_states = 1
+            self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_length, seq_length],
+            )
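A quick way to sanity-check the shape that the new test asserts is to run a tiny, randomly initialized BERT model; the config values and batch/sequence sizes below are arbitrary illustration values, not taken from the test suite:

import numpy as np
from transformers import BertConfig, FlaxBertModel

config = BertConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
)
model = FlaxBertModel(config)

batch_size, seq_length = 3, 7
input_ids = np.ones((batch_size, seq_length), dtype="i4")

outputs = model(input_ids, output_attentions=True)

# One attention tensor per layer, each (batch_size, num_attention_heads, seq_length, seq_length)
assert len(outputs.attentions) == config.num_hidden_layers
assert outputs.attentions[0].shape == (batch_size, config.num_attention_heads, seq_length, seq_length)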