"tests/vscode:/vscode.git/clone" did not exist on "05f2290114def72a3e20643fd4359c4c2d3abafe"
Unverified Commit e03966e4 authored by Joao Gante, committed by GitHub

TF: XLA stable softmax (#16892)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8246caf3
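
The commit is a mechanical substitution: every `tf.nn.softmax` call in the touched TF model files is replaced with `stable_softmax`, a thin wrapper exposed from `transformers.tf_utils`, and each file gains the corresponding import. The commit title points at an XLA-related numerical problem with `tf.nn.softmax` on heavily masked inputs; a common workaround is to perturb the logits by a negligible constant before the softmax. The snippet below is only a minimal sketch of such a wrapper, not a quote of the merged `tf_utils.py`; the exact signature and offset value are assumptions.

```python
import tensorflow as tf


def stable_softmax(logits: tf.Tensor, axis=None, name=None) -> tf.Tensor:
    """Drop-in replacement for tf.nn.softmax that behaves well under XLA.

    Adding a tiny constant to the logits is numerically a no-op after
    normalization, but steers XLA away from the code path that misbehaves
    on rows filled with large negative masking values.
    """
    # Hypothetical offset; the real wrapper may use a different (still tiny) value.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
```
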
@@ -33,7 +33,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -271,7 +271,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             ),
             lambda: attn_scores,
         )
-        attn_probs = tf.nn.softmax(attn_scores, axis=-1)
+        attn_probs = stable_softmax(attn_scores, axis=-1)
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         # Make sure to create a mask with the proper shape:
@@ -886,7 +886,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )
         # compute global attn probs
-        global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
+        global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1)
         # apply layer head masking
         if layer_head_mask is not None:
@@ -1085,7 +1085,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
                )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
             if tf.executing_eagerly():
......
@@ -34,7 +34,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
@@ -800,7 +800,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ),
             lambda: attn_scores,
         )
-        attn_probs = tf.nn.softmax(attn_scores, axis=-1)
+        attn_probs = stable_softmax(attn_scores, axis=-1)
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         # Make sure to create a mask with the proper shape:
@@ -1415,7 +1415,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         # compute global attn probs
-        global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
+        global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1)
         # apply layer head masking
         if layer_head_mask is not None:
......
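
The context lines kept in the LED and Longformer hunks ("softmax sometimes inserts NaN if all positions are masked, replace them with 0") describe the pattern surrounding the changed call. A compressed, hypothetical illustration of that pattern (simplified mask test and shapes, not the files' actual logic):

```python
import tensorflow as tf

# All key positions masked with a large negative score, as the attention code does.
attn_scores = tf.fill((1, 2, 4), -10000.0)

attn_probs = tf.nn.softmax(attn_scores, axis=-1)  # the call the diff swaps for stable_softmax

# Rows that were fully masked carry no information, so they are overwritten with zeros.
fully_masked = tf.reduce_all(attn_scores <= -10000.0, axis=-1, keepdims=True)
attn_probs = tf.where(fully_masked, tf.zeros_like(attn_probs), attn_probs)
```
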
@@ -22,6 +22,8 @@ from typing import Dict, Optional, Tuple
 import tensorflow as tf
 
+from transformers.tf_utils import stable_softmax
+
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list, unpack_inputs
 from ...utils import (
@@ -302,7 +304,7 @@ class TFLxmertAttention(tf.keras.layers.Layer):
             attention_scores = attention_scores + attention_mask
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+        attention_probs = stable_softmax(attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -284,7 +284,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
             # The tf.debugging asserts are not compliant with XLA then they
......
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -246,7 +246,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
             # The tf.debugging asserts are not compliant with XLA then they
......
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
@@ -278,7 +278,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
             attention_scores = attention_scores + attention_mask
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+        attention_probs = stable_softmax(attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     add_code_sample_docstrings,
@@ -241,7 +241,7 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
         if attention_mask is not None:
             attention_scores = attention_scores + attention_mask
-        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+        attention_probs = stable_softmax(attention_scores, axis=-1)
         attention_probs = self.dropout(attention_probs, training=training)
......
@@ -35,7 +35,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -111,7 +111,7 @@ class TFAttention(tf.keras.layers.Layer):
             attention_mask = tf.cast(attention_mask, dtype=w.dtype)
             w = w + attention_mask
-        w = tf.nn.softmax(w, axis=-1)
+        w = stable_softmax(w, axis=-1)
         w = self.attn_dropout(w, training=training)
         # Mask heads if we want to
......
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -285,7 +285,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
             # The tf.debugging asserts are not compliant with XLA then they
......
@@ -45,7 +45,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     DUMMY_INPUTS,
     MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -241,7 +241,7 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
             attention_scores = tf.add(attention_scores, attention_mask)
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     DUMMY_INPUTS,
     MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -290,7 +290,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
             attention_scores = tf.add(attention_scores, attention_mask)
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     add_code_sample_docstrings,
@@ -262,7 +262,7 @@ class TFRoFormerSelfAttention(tf.keras.layers.Layer):
             attention_scores = tf.add(attention_scores, attention_mask)
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -36,7 +36,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -348,7 +348,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
             # The tf.debugging asserts are not compliant with XLA then they
......
@@ -41,7 +41,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
@@ -398,7 +398,7 @@ class TFT5Attention(tf.keras.layers.Layer):
             position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)
         scores += position_bias
-        weights = tf.nn.softmax(scores, axis=-1)  # (batch_size, n_heads, query_length, key_length)
+        weights = stable_softmax(scores, axis=-1)  # (batch_size, n_heads, query_length, key_length)
         weights = self.dropout(weights, training=training)  # (batch_size, n_heads, query_length, key_length)
         # Mask heads if we want to
......
@@ -38,7 +38,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -346,7 +346,7 @@ class TFTapasSelfAttention(tf.keras.layers.Layer):
             attention_scores = tf.add(attention_scores, attention_mask)
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
@@ -2216,7 +2216,7 @@ def _calculate_expected_result(
         aggregation_op_only_probs = gumbel_dist.sample()
     else:
         # <float32>[batch_size, num_aggregation_labels - 1]
-        aggregation_op_only_probs = tf.nn.softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1)
+        aggregation_op_only_probs = stable_softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1)
     all_results = tf.concat(
         [
             tf.expand_dims(sum_result, axis=1),
......
@@ -31,7 +31,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -236,7 +236,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
             attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
         # [qlen x klen x bsz x n_head]
-        attn_prob = tf.nn.softmax(attn_score, axis=1)
+        attn_prob = stable_softmax(attn_score, axis=1)
         attn_prob = self.dropatt(attn_prob, training=training)
         # Mask heads if we want to
......
@@ -32,7 +32,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_vit import ViTConfig
@@ -260,7 +260,7 @@ class TFViTSelfAttention(tf.keras.layers.Layer):
         attention_scores = tf.divide(attention_scores, dk)
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -38,7 +38,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import logging
 from .configuration_vit_mae import ViTMAEConfig
@@ -407,7 +407,7 @@ class TFViTMAESelfAttention(tf.keras.layers.Layer):
         attention_scores = tf.divide(attention_scores, dk)
         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -25,7 +25,7 @@ import tensorflow as tf
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
@@ -855,7 +855,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
             # The tf.debugging asserts are not compliant with XLA then they
......
@@ -44,7 +44,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
@@ -187,7 +187,7 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
         # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
         mask = tf.cast(mask, dtype=scores.dtype)
         scores = scores - 1e30 * (1.0 - mask)
-        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
+        weights = stable_softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
         weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
         # Mask heads if we want to
......
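
Since the change is a pure call-site substitution, a quick sanity check is to confirm that `stable_softmax` matches `tf.nn.softmax` and still traces under XLA. The values below are hypothetical test inputs, and the check assumes the wrapper behaves as sketched after the commit header; the import path is the one shown in the diff above.

```python
import numpy as np
import tensorflow as tf
from transformers.tf_utils import stable_softmax  # import path as it appears in the diff

# Attention-like scores where half of the key positions are masked out.
scores = tf.random.normal((2, 4, 8)) + tf.concat(
    [tf.zeros((2, 4, 4)), tf.fill((2, 4, 4), -1e9)], axis=-1
)

eager = stable_softmax(scores, axis=-1)
compiled = tf.function(lambda x: stable_softmax(x, axis=-1), jit_compile=True)(scores)

# Agreement with the plain op (within float32 tolerance) and with the XLA-compiled path.
print(np.allclose(eager.numpy(), tf.nn.softmax(scores, axis=-1).numpy(), atol=1e-6))
print(np.allclose(eager.numpy(), compiled.numpy(), atol=1e-6))
```
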