Unverified Commit e03966e4 authored by Joao Gante, committed by GitHub

TF: XLA stable softmax (#16892)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8246caf3
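The diff below swaps `tf.nn.softmax` for a `stable_softmax` helper imported from the library's `tf_utils` module, so that the softmax in each TF attention layer behaves reliably when the model is compiled with XLA. The helper's definition is not part of this page; as a hedged sketch, a wrapper of this kind can simply shift the logits by a negligible constant (softmax is shift-invariant, so the output is unchanged) before delegating to `tf.nn.softmax`:

```python
from typing import Optional

import tensorflow as tf


def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
    # Sketch only: adding a tiny constant leaves the softmax output numerically
    # identical (softmax(x) == softmax(x + c)) while sidestepping an
    # XLA-on-CPU instability in tf.nn.softmax.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
```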
@@ -33,7 +33,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -271,7 +271,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
),
lambda: attn_scores,
)
-attn_probs = tf.nn.softmax(attn_scores, axis=-1)
+attn_probs = stable_softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
@@ -886,7 +886,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
)
# compute global attn probs
-global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
+global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1)
# apply layer head masking
if layer_head_mask is not None:
@@ -1085,7 +1085,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
if tf.executing_eagerly():
@@ -34,7 +34,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
@@ -800,7 +800,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
),
lambda: attn_scores,
)
-attn_probs = tf.nn.softmax(attn_scores, axis=-1)
+attn_probs = stable_softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
@@ -1415,7 +1415,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# compute global attn probs
-global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
+global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1)
# apply layer head masking
if layer_head_mask is not None:
@@ -22,6 +22,8 @@ from typing import Dict, Optional, Tuple
import tensorflow as tf
+from transformers.tf_utils import stable_softmax
from ...activations_tf import get_tf_activation
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list, unpack_inputs
from ...utils import (
@@ -302,7 +304,7 @@ class TFLxmertAttention(tf.keras.layers.Layer):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+attention_probs = stable_softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -284,7 +284,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -246,7 +246,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
@@ -278,7 +278,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+attention_probs = stable_softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
@@ -241,7 +241,7 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
-attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+attention_probs = stable_softmax(attention_scores, axis=-1)
attention_probs = self.dropout(attention_probs, training=training)
@@ -35,7 +35,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -111,7 +111,7 @@ class TFAttention(tf.keras.layers.Layer):
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
w = w + attention_mask
-w = tf.nn.softmax(w, axis=-1)
+w = stable_softmax(w, axis=-1)
w = self.attn_dropout(w, training=training)
# Mask heads if we want to
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -285,7 +285,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
@@ -45,7 +45,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -241,7 +241,7 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -290,7 +290,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
@@ -262,7 +262,7 @@ class TFRoFormerSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -36,7 +36,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -348,7 +348,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
@@ -41,7 +41,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
@@ -398,7 +398,7 @@ class TFT5Attention(tf.keras.layers.Layer):
position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length)
scores += position_bias
-weights = tf.nn.softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length)
+weights = stable_softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length)
weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length)
# Mask heads if we want to
@@ -38,7 +38,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -346,7 +346,7 @@ class TFTapasSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -2216,7 +2216,7 @@ def _calculate_expected_result(
aggregation_op_only_probs = gumbel_dist.sample()
else:
# <float32>[batch_size, num_aggregation_labels - 1]
-aggregation_op_only_probs = tf.nn.softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1)
+aggregation_op_only_probs = stable_softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1)
all_results = tf.concat(
[
tf.expand_dims(sum_result, axis=1),
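The TAPAS hunk above is the one call site that is not plain attention: the aggregation logits are divided by a temperature before the softmax. The swap still stays drop-in because softmax is invariant to an added constant (which is all the stable wrapper presumably adds), even though it is not invariant to scaling. A small illustration with made-up numbers:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
temperature = 4.0  # hypothetical stand-in for config.aggregation_temperature

# Dividing by a temperature genuinely changes the distribution...
smooth = tf.nn.softmax(logits / temperature, axis=-1)

# ...but adding a tiny constant afterwards, as the stable wrapper is assumed to do, does not.
shifted = tf.nn.softmax(logits / temperature + 1e-9, axis=-1)
tf.debugging.assert_near(smooth, shifted)
```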
@@ -31,7 +31,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -236,7 +236,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
# [qlen x klen x bsz x n_head]
-attn_prob = tf.nn.softmax(attn_score, axis=1)
+attn_prob = stable_softmax(attn_score, axis=1)
attn_prob = self.dropatt(attn_prob, training=training)
# Mask heads if we want to
@@ -32,7 +32,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_vit import ViTConfig
@@ -260,7 +260,7 @@ class TFViTSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.divide(attention_scores, dk)
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -38,7 +38,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_vit_mae import ViTMAEConfig
@@ -407,7 +407,7 @@ class TFViTMAESelfAttention(tf.keras.layers.Layer):
attention_scores = tf.divide(attention_scores, dk)
# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -25,7 +25,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...tokenization_utils_base import BatchEncoding
from ...utils import (
ModelOutput,
@@ -855,7 +855,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
@@ -44,7 +44,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
@@ -187,7 +187,7 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
-weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
+weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
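Every replacement keeps the original arguments (including the `logits=` keyword some models use), so the change is mechanical at each call site. A hedged smoke test, assuming the helper is exposed as `transformers.tf_utils.stable_softmax` exactly as the LXMERT hunk imports it, could look like this:

```python
import tensorflow as tf

from transformers.tf_utils import stable_softmax


@tf.function(jit_compile=True)
def xla_attention_probs(scores: tf.Tensor) -> tf.Tensor:
    # Same call signature as tf.nn.softmax, so attention code needs no other edits.
    return stable_softmax(scores, axis=-1)


scores = tf.random.normal((2, 4, 8))
tf.debugging.assert_near(xla_attention_probs(scores), tf.nn.softmax(scores, axis=-1), atol=1e-6)
```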