"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7f04373865393f625fb8f20bdabdab188120f9b8"
Unverified commit e03966e4, authored by Joao Gante and committed by GitHub

TF: XLA stable softmax (#16892)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8246caf3
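
This commit routes the TensorFlow softmax calls in the generation utilities and model files through a single stable_softmax helper in tf_utils, so the behaviour under XLA compilation can be controlled in one place. The helper's body is not part of this diff; the following is only a minimal sketch of what such a wrapper could look like, assuming the commonly used workaround of adding a negligible constant to the logits before calling tf.nn.softmax (the exact library implementation may differ):

from typing import Optional

import tensorflow as tf


def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
    # Sketch only: a tiny additive constant leaves the probabilities effectively
    # unchanged while sidestepping the problematic XLA lowering of a bare softmax.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)

Every call site keeps its previous arguments, e.g. stable_softmax(attention_scores, axis=-1) simply replaces tf.nn.softmax(attention_scores, axis=-1).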
@@ -19,6 +19,7 @@ from typing import List
 import numpy as np
 import tensorflow as tf
+from .tf_utils import stable_softmax
 from .utils import add_start_docstrings
 from .utils.logging import get_logger
@@ -166,7 +167,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
 topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
 mask_scores = tf.fill(scores.shape, self.filter_value)
-cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
+cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
 score_mask = cumulative_probs < self.top_p
 # Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
......
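
For orientation, the TFTopPLogitsWarper hunk above lives inside top-p (nucleus) filtering. Below is a stand-alone, illustrative-only reconstruction of that cumulative-probability masking with made-up numbers; plain tf.nn.softmax stands in for the library's stable_softmax, and the real warper additionally scatters the result back to vocabulary order via topk_indices:

import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5, -1.0]])  # hypothetical (batch_size, vocab_size) logits
top_p = 0.9
filter_value = -float("inf")

# Sort all logits in descending order and accumulate their probability mass.
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)

# Keep tokens while the cumulative probability is still below top_p, plus the
# first token that crosses the threshold (shift the mask right, insert a True).
score_mask = cumulative_probs < top_p
score_mask = tf.concat([tf.ones_like(score_mask[:, :1]), score_mask[:, :-1]], axis=-1)

filtered_scores = tf.where(score_mask, topk_scores, filter_value)
print(filtered_scores)  # tokens outside the nucleus are pushed to -inf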
@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
 TFTopKLogitsWarper,
 TFTopPLogitsWarper,
 )
-from .tf_utils import shape_list
+from .tf_utils import shape_list, stable_softmax
 from .utils import ModelOutput, logging
@@ -3060,7 +3060,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
 logits, sorted_indices, axis=-1, batch_dims=1
 ) # expects logits to be of dim (batch_size, vocab_size)
-cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
+cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)
 # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
 sorted_indices_to_remove = cumulative_probs > top_p
......
@@ -44,7 +44,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 MULTIPLE_CHOICE_DUMMY_INPUTS,
 ModelOutput,
@@ -259,7 +259,7 @@ class TFAlbertAttention(tf.keras.layers.Layer):
 attention_scores = tf.add(attention_scores, attention_mask)
 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -40,7 +40,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 add_code_sample_docstrings,
 add_end_docstrings,
@@ -244,7 +244,7 @@ class TFBartAttention(tf.keras.layers.Layer):
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they
......
@@ -49,7 +49,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 DUMMY_INPUTS,
 MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -322,7 +322,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
 attention_scores = tf.add(attention_scores, attention_mask)
 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -40,7 +40,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 add_code_sample_docstrings,
 add_end_docstrings,
@@ -245,7 +245,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they
......
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 add_code_sample_docstrings,
 add_end_docstrings,
@@ -245,7 +245,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they
......
@@ -34,7 +34,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 ModelOutput,
 add_start_docstrings,
@@ -333,7 +333,7 @@ class TFCLIPAttention(tf.keras.layers.Layer):
 attention_scores = tf.add(attention_scores, attention_mask)
 # Normalize the attention scores to probabilities.
-_attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+_attention_probs = stable_softmax(logits=attention_scores, axis=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 MULTIPLE_CHOICE_DUMMY_INPUTS,
 add_code_sample_docstrings,
@@ -228,7 +228,7 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
 conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
 conv_kernel_layer = tf.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
-conv_kernel_layer = tf.nn.softmax(conv_kernel_layer, axis=1)
+conv_kernel_layer = stable_softmax(conv_kernel_layer, axis=1)
 paddings = tf.constant(
 [
@@ -270,7 +270,7 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
 attention_scores = attention_scores + attention_mask
 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+attention_probs = stable_softmax(attention_scores, axis=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -31,7 +31,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_ctrl import CTRLConfig
@@ -79,7 +79,7 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
 attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
 scaled_attention_logits = scaled_attention_logits + attention_mask
-attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
+attention_weights = stable_softmax(scaled_attention_logits, axis=-1)
 # Mask heads if we want to
 if head_mask is not None:
......
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
 get_initializer,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_deberta import DebertaConfig
@@ -96,7 +96,7 @@ class TFDebertaXSoftmax(tf.keras.layers.Layer):
 rmask = tf.logical_not(tf.cast(mask, tf.bool))
 output = tf.where(rmask, float("-inf"), inputs)
-output = tf.nn.softmax(output, self.axis)
+output = stable_softmax(output, self.axis)
 output = tf.where(rmask, 0.0, output)
 return output
......
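
The TFDebertaXSoftmax hunk above (and its DebertaV2 counterpart below) uses a mask-then-softmax-then-zero pattern rather than a plain softmax. A small stand-alone illustration with hypothetical values, again using tf.nn.softmax where the layer now calls stable_softmax:

import tensorflow as tf

inputs = tf.constant([[1.0, 2.0, 3.0, 4.0]])  # hypothetical attention logits
mask = tf.constant([[1, 1, 0, 1]])            # 0 marks positions to exclude

rmask = tf.logical_not(tf.cast(mask, tf.bool))   # True where a position is masked
output = tf.where(rmask, float("-inf"), inputs)  # masked logits become -inf
output = tf.nn.softmax(output, axis=-1)          # -inf positions get zero probability
output = tf.where(rmask, 0.0, output)            # re-assert exact zeros at masked slots
print(output)                                    # remaining probabilities sum to 1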
@@ -38,7 +38,7 @@ from ...modeling_tf_utils import (
 get_initializer,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_deberta_v2 import DebertaV2Config
@@ -97,7 +97,7 @@ class TFDebertaV2XSoftmax(tf.keras.layers.Layer):
 rmask = tf.logical_not(tf.cast(mask, tf.bool))
 output = tf.where(rmask, float("-inf"), inputs)
-output = tf.nn.softmax(output, self.axis)
+output = stable_softmax(output, self.axis)
 output = tf.where(rmask, 0.0, output)
 return output
......
@@ -43,7 +43,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 MULTIPLE_CHOICE_DUMMY_INPUTS,
 add_code_sample_docstrings,
@@ -194,7 +194,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
 mask = tf.cast(mask, dtype=scores.dtype)
 scores = scores - 1e30 * (1.0 - mask)
-weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
+weights = stable_softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
 weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
 # Mask heads if we want to
......
@@ -44,7 +44,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 DUMMY_INPUTS,
 MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -171,7 +171,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
 attention_scores = tf.add(attention_scores, attention_mask)
 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
......
@@ -34,7 +34,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 ModelOutput,
 add_code_sample_docstrings,
@@ -361,7 +361,7 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
 # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
 mask = tf.cast(mask, dtype=scores.dtype)
 scores = scores - 1e30 * (1.0 - mask)
-weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
+weights = stable_softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
 weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
 # Mask heads if we want to
......
@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 MULTIPLE_CHOICE_DUMMY_INPUTS,
 ModelOutput,
@@ -530,7 +530,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
 attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
 # attention probability
-attn_prob = tf.nn.softmax(attn_score, axis=-1)
+attn_prob = stable_softmax(attn_score, axis=-1)
 attn_prob = self.attention_dropout(attn_prob, training=training)
 # attention output, shape batch_size x seq_len x n_head x d_head
......
@@ -40,7 +40,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
 DUMMY_INPUTS,
 ModelOutput,
@@ -129,7 +129,7 @@ class TFAttention(tf.keras.layers.Layer):
 attention_mask = tf.cast(attention_mask, dtype=w.dtype)
 w = w + attention_mask
-w = tf.nn.softmax(w, axis=-1)
+w = stable_softmax(w, axis=-1)
 w = self.attn_dropout(w, training=training)
 # Mask heads if we want to
......
@@ -43,7 +43,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import logging
 from .configuration_gptj import GPTJConfig
@@ -191,7 +191,7 @@ class TFGPTJAttention(tf.keras.layers.Layer):
 # Apply the attention mask
 attn_weights = attn_weights + attention_mask
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
 attn_weights = tf.cast(attn_weights, value.dtype)
 attn_weights = self.attn_dropout(attn_weights)
......
@@ -23,7 +23,7 @@ import tensorflow as tf
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
 ModelOutput,
@@ -826,7 +826,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
 attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
 attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
-attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+attn_weights = stable_softmax(attn_weights, axis=-1)
 if layer_head_mask is not None:
 # The tf.debugging asserts are not compliant with XLA then they
......
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
 keras_serializable,
 unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_layoutlm import LayoutLMConfig
@@ -280,7 +280,7 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
 attention_scores = tf.add(attention_scores, attention_mask)
 # Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+attention_probs = stable_softmax(logits=attention_scores, axis=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.
......