Unverified commit e03966e4, authored by Joao Gante, committed by GitHub
Browse files

TF: XLA stable softmax (#16892)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8246caf3
......@@ -19,6 +19,7 @@ from typing import List
import numpy as np
import tensorflow as tf
from .tf_utils import stable_softmax
from .utils import add_start_docstrings
from .utils.logging import get_logger
......@@ -166,7 +167,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
mask_scores = tf.fill(scores.shape, self.filter_value)
cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
score_mask = cumulative_probs < self.top_p
# Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
......
......@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import shape_list
from .tf_utils import shape_list, stable_softmax
from .utils import ModelOutput, logging
......@@ -3060,7 +3060,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
logits, sorted_indices, axis=-1, batch_dims=1
) # expects logits to be of dim (batch_size, vocab_size)
cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
......
......@@ -44,7 +44,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
......@@ -259,7 +259,7 @@ class TFAlbertAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
......@@ -40,7 +40,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
......@@ -244,7 +244,7 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
......
......@@ -49,7 +49,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
......@@ -322,7 +322,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
......@@ -40,7 +40,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
......@@ -245,7 +245,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
......
......@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
......@@ -245,7 +245,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
......
......@@ -34,7 +34,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_start_docstrings,
......@@ -333,7 +333,7 @@ class TFCLIPAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
_attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
_attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
......@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
......@@ -228,7 +228,7 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
conv_kernel_layer = tf.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
conv_kernel_layer = tf.nn.softmax(conv_kernel_layer, axis=1)
conv_kernel_layer = stable_softmax(conv_kernel_layer, axis=1)
paddings = tf.constant(
[
......@@ -270,7 +270,7 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
attention_probs = stable_softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
......@@ -31,7 +31,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_ctrl import CTRLConfig
......@@ -79,7 +79,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
scaled_attention_logits = scaled_attention_logits + attention_mask
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
attention_weights = stable_softmax(scaled_attention_logits, axis=-1)
# Mask heads if we want to
if head_mask is not None:
......
......@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
get_initializer,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_deberta import DebertaConfig
......@@ -96,7 +96,7 @@ class TFDebertaXSoftmax(tf.keras.layers.Layer):
rmask = tf.logical_not(tf.cast(mask, tf.bool))
output = tf.where(rmask, float("-inf"), inputs)
output = tf.nn.softmax(output, self.axis)
output = stable_softmax(output, self.axis)
output = tf.where(rmask, 0.0, output)
return output
......
......@@ -38,7 +38,7 @@ from ...modeling_tf_utils import (
get_initializer,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_deberta_v2 import DebertaV2Config
......@@ -97,7 +97,7 @@ class TFDebertaV2XSoftmax(tf.keras.layers.Layer):
rmask = tf.logical_not(tf.cast(mask, tf.bool))
output = tf.where(rmask, float("-inf"), inputs)
output = tf.nn.softmax(output, self.axis)
output = stable_softmax(output, self.axis)
output = tf.where(rmask, 0.0, output)
return output
......
......@@ -43,7 +43,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
......@@ -194,7 +194,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
......
......@@ -44,7 +44,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
MULTIPLE_CHOICE_DUMMY_INPUTS,
......@@ -171,7 +171,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
......@@ -34,7 +34,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
......@@ -361,7 +361,7 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
......
......@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
......@@ -530,7 +530,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
# attention probability
attn_prob = tf.nn.softmax(attn_score, axis=-1)
attn_prob = stable_softmax(attn_score, axis=-1)
attn_prob = self.attention_dropout(attn_prob, training=training)
# attention output, shape batch_size x seq_len x n_head x d_head
......
......@@ -40,7 +40,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
ModelOutput,
......@@ -129,7 +129,7 @@ class TFAttention(tf.keras.layers.Layer):
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
w = w + attention_mask
w = tf.nn.softmax(w, axis=-1)
w = stable_softmax(w, axis=-1)
w = self.attn_dropout(w, training=training)
# Mask heads if we want to
......
......@@ -43,7 +43,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_gptj import GPTJConfig
......@@ -191,7 +191,7 @@ class TFGPTJAttention(tf.keras.layers.Layer):
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
attn_weights = tf.cast(attn_weights, value.dtype)
attn_weights = self.attn_dropout(attn_weights)
......
......@@ -23,7 +23,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...tokenization_utils_base import BatchEncoding
from ...utils import (
ModelOutput,
......@@ -826,7 +826,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
......
......@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlm import LayoutLMConfig
......@@ -280,7 +280,7 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment