Unverified Commit e03966e4 authored by Joao Gante, committed by GitHub

TF: XLA stable softmax (#16892)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8246caf3
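In short: the commit adds a `stable_softmax` helper to `transformers.tf_utils` and swaps it in for `tf.nn.softmax` across the TF models, pipelines, and cookiecutter templates, as a workaround for a softmax bug that shows up when compiling with XLA on CPU (tensorflow/tensorflow#55682). The snippet below is not part of the commit; it only illustrates the call-site pattern the diff applies everywhere, with made-up tensor shapes and assuming a transformers install recent enough to ship `stable_softmax`.

```python
import tensorflow as tf

from transformers.tf_utils import stable_softmax  # helper added by this commit

# Toy attention scores and an additive mask; shapes and values are invented for illustration.
attn_scores = tf.random.normal((1, 4, 8, 8))  # (batch, heads, query, key)
attn_mask = tf.concat([tf.zeros((1, 1, 1, 6)), -1e9 * tf.ones((1, 1, 1, 2))], axis=-1)
masked_scores = attn_scores + attn_mask       # broadcasts over the key axis

# Before this commit: attn_probs = tf.nn.softmax(masked_scores, axis=-1)
attn_probs = stable_softmax(masked_scores, axis=-1)  # after: a drop-in replacement
print(attn_probs.shape)  # (1, 4, 8, 8)
```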
......@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
......@@ -159,7 +159,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_score = attn_score - 1e30 * attn_mask
# attention probability
attn_prob = tf.nn.softmax(attn_score, axis=1)
attn_prob = stable_softmax(attn_score, axis=1)
attn_prob = self.dropout(attn_prob, training=training)
......
......@@ -9,6 +9,8 @@ from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline, PipelineException
if is_tf_available():
import tensorflow as tf
from ..tf_utils import stable_softmax
if is_torch_available():
import torch
......@@ -101,7 +103,7 @@ class FillMaskPipeline(Pipeline):
outputs = outputs.numpy()
logits = outputs[0, masked_index, :]
probs = tf.nn.softmax(logits, axis=-1)
probs = stable_softmax(logits, axis=-1)
if target_ids is not None:
probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
probs = tf.expand_dims(probs, 0)
......
......@@ -20,6 +20,7 @@ if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
from ..tf_utils import stable_softmax
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
......@@ -103,7 +104,7 @@ class ImageClassificationPipeline(Pipeline):
probs = model_outputs.logits.softmax(-1)[0]
scores, ids = probs.topk(top_k)
elif self.framework == "tf":
probs = tf.nn.softmax(model_outputs.logits, axis=-1)[0]
probs = stable_softmax(model_outputs.logits, axis=-1)[0]
topk = tf.math.top_k(probs, k=top_k)
scores, ids = topk.values.numpy(), topk.indices.numpy()
else:
......
......@@ -22,6 +22,8 @@ if is_torch_available():
if is_tf_available():
import tensorflow as tf
from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__)
......@@ -119,7 +121,7 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
scores = probs.tolist()
else:
logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0)
probs = tf.nn.softmax(logits, axis=0)
probs = stable_softmax(logits, axis=0)
scores = probs.numpy().tolist()
result = [
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import List, Optional, Union
import numpy as np
import tensorflow as tf
......@@ -44,3 +44,27 @@ def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
"""
Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is
meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be
removed after it gets fixed. The arguments and outputs are the same as those of `tf.nn.softmax`, and it relies on the fact that
`softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).
Args:
logits (`tf.Tensor`):
Must be one of the following types: half, float32, float64.
axis (`int`, *optional*):
The dimension softmax would be performed on. The default is -1 which indicates the last dimension.
name (`str`, *optional*):
A name for the operation.
Returns:
`tf.Tensor`:
A Tensor. Has the same type and shape as logits.
"""
# TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
# it has the fix. After we drop the support for unfixed versions, remove this function.
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
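The docstring leans on the shift invariance of softmax: softmax(x + c)_i = exp(x_i + c) / sum_j exp(x_j + c) = (e^c * exp(x_i)) / (e^c * sum_j exp(x_j)) = softmax(x)_i, so the added 1e-9 changes nothing beyond float noise. A stand-alone check of that claim (not part of the diff, plain TensorFlow only):

```python
import tensorflow as tf

x = tf.random.normal((3, 7))
c = 5.0  # any constant shift along the softmax axis

shifted = tf.nn.softmax(x + c, axis=-1)
plain = tf.nn.softmax(x, axis=-1)
print(tf.reduce_max(tf.abs(shifted - plain)).numpy())  # float32 noise, on the order of 1e-7

# stable_softmax computes tf.nn.softmax(x + 1e-9), so it is a numerical no-op:
wrapped = tf.nn.softmax(x + 1e-9, axis=-1)
print(tf.reduce_max(tf.abs(wrapped - plain)).numpy())
```

Adding a tiny constant (rather than, say, subtracting the row max) keeps the wrapper a one-line, XLA-friendly change that can simply be deleted once the upstream fix lands, as the TODO above notes.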
......@@ -53,7 +53,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
......@@ -244,7 +244,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
......@@ -1665,8 +1665,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
......@@ -1855,7 +1855,7 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
......
......@@ -16,7 +16,7 @@
import unittest
from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import get_gpu_count, require_tf, slow
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -536,8 +536,6 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_strings, expected_output_string)
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_lm_generate_gpt2_greedy_xla(self):
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
# the underlying problem)
......@@ -563,30 +561,33 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_strings, expected_output_strings)
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_lm_generate_gpt2_sample_xla(self):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
# and that we can seed both versions.
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
sentence = ["The dog"]
expected_output_string = [
"The dog must be well educated to do anything. If anything, this must be her best friend"
]
expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
sentence = ["The dog"]
expected_output_string = [
"The dog owner asked why did our vet decide there needed to be extra ventilation inside because most puppies"
]
expected_output_string_xla = [
"The dog has been named in connection with the murder of a 20-year-old man in!"
]
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)
......@@ -16,7 +16,7 @@
import unittest
from transformers import T5Config, is_tf_available
from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ..test_configuration_common import ConfigTester
......@@ -481,8 +481,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tokenizers
class TFT5GenerationIntegrationTests(unittest.TestCase):
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_greedy_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
......@@ -534,30 +532,31 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_sample_xla_generate_simple(self):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
# results are sensible and that we can seed both versions.
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentence = "Translate English to German: I have two bananas"
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
expected_output_string = ["Ich habe 2 Bananen"]
expected_output_string_xla = ["Ich habe 2 Bananen"]
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
# and that we can seed both versions.
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentence = "Translate English to German: I have two bananas"
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
expected_output_string = ["Ich habe zwei Bananen"]
expected_output_string_xla = ["Ich habe 2 Bananen"]
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)
@slow
def test_sample_generate(self):
......
......@@ -84,6 +84,7 @@ if is_tf_available():
TFSampleEncoderDecoderOutput,
)
from transformers.modeling_tf_utils import unpack_inputs
from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
......@@ -1709,6 +1710,41 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertFalse(output[3])
self.assertFalse(output[4])
# Tests whether the stable softmax is stable on CPU, with and without XLA
def test_xla_stable_softmax(self):
large_penalty = -1e9
n_tokens = 10
batch_size = 8
def masked_softmax(x, boolean_mask):
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
masked_x = x + numerical_mask
return stable_softmax(masked_x)
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
x = tf.random.normal((batch_size, n_tokens))
# Same outcome regardless of the boolean mask here
masked_tokens = random.randint(0, n_tokens)
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
# We can randomly mask a random numerical input OUTSIDE XLA
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
masked_x = x + numerical_mask
xla_out = xla_stable_softmax(masked_x)
out = stable_softmax(masked_x)
assert tf.experimental.numpy.allclose(xla_out, out)
# The stable softmax has the same output as the original softmax
unstable_out = tf.nn.softmax(masked_x)
assert tf.experimental.numpy.allclose(unstable_out, out)
# We can randomly mask a random numerical input INSIDE XLA
xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
@require_tf
@is_staging_test
......
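For context, the failure mode the new `test_xla_stable_softmax` guards against can be probed directly: on an affected TF version (per tensorflow/tensorflow#55682), plain `tf.nn.softmax` compiled with XLA on CPU can disagree with its eager result on heavily masked inputs, while `stable_softmax` is expected to stay consistent. A hedged repro sketch, assuming a CPU run and a transformers version that includes this commit; whether the first comparison actually diverges depends on the installed TF version:

```python
import tensorflow as tf

from transformers.tf_utils import stable_softmax

with tf.device("/CPU:0"):
    # Half of each row is masked out with a large negative value, as in the test above.
    logits = tf.random.normal((8, 10)) + tf.constant([[0.0] * 5 + [-1e9] * 5])

    xla_plain = tf.function(tf.nn.softmax, jit_compile=True)
    xla_stable = tf.function(stable_softmax, jit_compile=True)

    # On an affected TF version the first difference may be far from zero;
    # the second should stay at float noise regardless.
    print(tf.reduce_max(tf.abs(xla_plain(logits) - tf.nn.softmax(logits))).numpy())
    print(tf.reduce_max(tf.abs(xla_stable(logits) - stable_softmax(logits))).numpy())
```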