Unverified Commit e03966e4 authored by Joao Gante, committed by GitHub

TF: XLA stable softmax (#16892)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8246caf3
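At a glance: `tf.nn.softmax` misbehaves when compiled with XLA on CPU (see tensorflow/tensorflow#55682), so this commit routes TF softmax calls through a `stable_softmax` wrapper that adds a tiny constant to the logits before calling `tf.nn.softmax` (relying on softmax(x) = softmax(x + c)), and re-enables the GPT-2/T5 XLA generation tests on CPU. A minimal sketch of the workaround follows; the 1e-9 offset and the masked-attention pattern come from the diff below, while the shapes and values are illustrative assumptions:

```python
import tensorflow as tf


def stable_softmax(logits, axis=None, name=None):
    # softmax(x) == softmax(x + c): the tiny offset leaves the result unchanged,
    # but sidesteps the XLA-on-CPU issue in tf.nn.softmax.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)


# Attention-style masking that exercised the problem under jit_compile:
scores = tf.random.normal((2, 4))
mask = tf.constant([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
masked_scores = scores + (1.0 - mask) * -1e9

xla_softmax = tf.function(stable_softmax, jit_compile=True)
probs = xla_softmax(masked_scores, axis=-1)  # rows sum to 1, masked positions ~0
print(probs)
```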
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
@@ -159,7 +159,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             attn_score = attn_score - 1e30 * attn_mask

         # attention probability
-        attn_prob = tf.nn.softmax(attn_score, axis=1)
+        attn_prob = stable_softmax(attn_score, axis=1)

         attn_prob = self.dropout(attn_prob, training=training)
...
@@ -9,6 +9,8 @@ from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline, PipelineException
 if is_tf_available():
     import tensorflow as tf

+    from ..tf_utils import stable_softmax
+
 if is_torch_available():
     import torch
@@ -101,7 +103,7 @@ class FillMaskPipeline(Pipeline):
             outputs = outputs.numpy()

             logits = outputs[0, masked_index, :]
-            probs = tf.nn.softmax(logits, axis=-1)
+            probs = stable_softmax(logits, axis=-1)
             if target_ids is not None:
                 probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
                 probs = tf.expand_dims(probs, 0)
...
@@ -20,6 +20,7 @@ if is_tf_available():
     import tensorflow as tf

     from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+    from ..tf_utils import stable_softmax

 if is_torch_available():
     from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
@@ -103,7 +104,7 @@ class ImageClassificationPipeline(Pipeline):
             probs = model_outputs.logits.softmax(-1)[0]
             scores, ids = probs.topk(top_k)
         elif self.framework == "tf":
-            probs = tf.nn.softmax(model_outputs.logits, axis=-1)[0]
+            probs = stable_softmax(model_outputs.logits, axis=-1)[0]
             topk = tf.math.top_k(probs, k=top_k)
             scores, ids = topk.values.numpy(), topk.indices.numpy()
         else:
...
@@ -22,6 +22,8 @@ if is_torch_available():
 if is_tf_available():
     import tensorflow as tf

+    from ..tf_utils import stable_softmax
+
 logger = logging.get_logger(__name__)
@@ -119,7 +121,7 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
             scores = probs.tolist()
         else:
             logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0)
-            probs = tf.nn.softmax(logits, axis=0)
+            probs = stable_softmax(logits, axis=0)
             scores = probs.numpy().tolist()

         result = [
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Union
+from typing import List, Optional, Union

 import numpy as np
 import tensorflow as tf
@@ -44,3 +44,27 @@ def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
     static = tensor.shape.as_list()

     return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+
+
+def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
+    """
+    Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is
+    meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will
+    be removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and it relies on the
+    fact that `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).
+
+    Args:
+        logits (`tf.Tensor`):
+            Must be one of the following types: half, float32, float64.
+        axis (`int`, *optional*):
+            The dimension softmax would be performed on. The default is -1, which indicates the last dimension.
+        name (`str`, *optional*):
+            A name for the operation.
+
+    Returns:
+        `tf.Tensor`:
+            A Tensor. Has the same type and shape as `logits`.
+    """
+    # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
+    # it has the fix. After we drop support for unfixed versions, remove this function.
+    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
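Since the wrapper keeps the `tf.nn.softmax` signature, call sites only need to swap the function name, as the hunks above and below show. A small usage sketch, assuming a transformers build that includes this commit (the example scores are made up):

```python
import tensorflow as tf

from transformers.tf_utils import stable_softmax

scores = tf.constant([[2.0, 1.0, -1e9, -1e9]])  # last two positions masked out

eager_probs = stable_softmax(scores, axis=-1)
xla_probs = tf.function(stable_softmax, jit_compile=True)(scores, axis=-1)

# Matches the plain softmax and stays well-behaved under XLA on CPU
tf.debugging.assert_near(eager_probs, tf.nn.softmax(scores, axis=-1), atol=1e-6)
tf.debugging.assert_near(eager_probs, xla_probs, atol=1e-6)
```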
@@ -53,7 +53,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -244,7 +244,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
             attention_scores = tf.add(attention_scores, attention_mask)

         # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
@@ -1665,8 +1665,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     keras_serializable,
     unpack_inputs,
-); from ...tf_utils import (shape_list,
 )
+from ...tf_utils import shape_list, stable_softmax
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -1855,7 +1855,7 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

-        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+        attn_weights = stable_softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
             # The tf.debugging asserts are not compliant with XLA then they
...
@@ -16,7 +16,7 @@
 import unittest

 from transformers import GPT2Config, is_tf_available
-from transformers.testing_utils import get_gpu_count, require_tf, slow
+from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -536,8 +536,6 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         self.assertListEqual(output_strings, expected_output_string)

     @slow
-    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
-    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_lm_generate_gpt2_greedy_xla(self):
         # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
         # the underlying problem)
@@ -563,30 +561,33 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         self.assertListEqual(output_strings, expected_output_strings)

     @slow
-    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
-    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_lm_generate_gpt2_sample_xla(self):
         # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
         # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
         # and that we can seed both versions.
-        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
-        sentence = ["The dog"]
-        expected_output_string = [
-            "The dog must be well educated to do anything. If anything, this must be her best friend"
-        ]
-        expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
-        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
-        output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
-        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        self.assertListEqual(output_strings, expected_output_string)
-        xla_generate = tf.function(model.generate, jit_compile=True)
-        output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
-        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        self.assertListEqual(output_strings, expected_output_string_xla)
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.padding_side = "left"
+
+            sentence = ["The dog"]
+            expected_output_string = [
+                "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most puppies"
+            ]
+            expected_output_string_xla = [
+                "The dog has been named in connection with the murder of a 20-year-old man in!"
+            ]
+            input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+
+            output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+            output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+            self.assertListEqual(output_strings, expected_output_string)
+
+            xla_generate = tf.function(model.generate, jit_compile=True)
+            output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+            output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+            self.assertListEqual(output_strings, expected_output_string_xla)
@@ -16,7 +16,7 @@
 import unittest

 from transformers import T5Config, is_tf_available
-from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
 from transformers.utils import cached_property

 from ..test_configuration_common import ConfigTester
@@ -481,8 +481,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
 @require_tokenizers
 class TFT5GenerationIntegrationTests(unittest.TestCase):
     @slow
-    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
-    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_greedy_xla_generate_simple(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
         tokenizer = T5Tokenizer.from_pretrained("t5-small")
@@ -534,30 +532,31 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(expected_output_string, output_strings)

     @slow
-    @unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
-    # TODO: remove the skip when the XLA CPU softmax issue gets sorted
     def test_sample_xla_generate_simple(self):
         # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
-        # output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
-        # results are sensible and that we can seed both versions.
-        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
-        tokenizer = T5Tokenizer.from_pretrained("t5-small")
-        sentence = "Translate English to German: I have two bananas"
-        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
-        expected_output_string = ["Ich habe 2 Bananen"]
-        expected_output_string_xla = ["Ich habe 2 Bananen"]
-        # seed set -> deterministic sampling sequence -> deterministic generation
-        output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
-        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        self.assertListEqual(expected_output_string, output_strings)
-        xla_generate = tf.function(model.generate, jit_compile=True)
-        # seed set -> deterministic sampling sequence -> deterministic generation
-        output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
-        output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
-        self.assertListEqual(expected_output_string_xla, output_strings_xla)
+        # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
+        # and that we can seed both versions.
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+            tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+            sentence = "Translate English to German: I have two bananas"
+            input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+            expected_output_string = ["Ich habe zwei Bananen"]
+            expected_output_string_xla = ["Ich habe 2 Bananen"]
+
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
+            output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+            self.assertListEqual(expected_output_string, output_strings)
+
+            xla_generate = tf.function(model.generate, jit_compile=True)
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+            output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
+            self.assertListEqual(expected_output_string_xla, output_strings_xla)

     @slow
     def test_sample_generate(self):
...
@@ -84,6 +84,7 @@ if is_tf_available():
         TFSampleEncoderDecoderOutput,
     )
     from transformers.modeling_tf_utils import unpack_inputs
+    from transformers.tf_utils import stable_softmax

     if _tf_gpu_memory_limit is not None:
         gpus = tf.config.list_physical_devices("GPU")
@@ -1709,6 +1710,41 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertFalse(output[3])
         self.assertFalse(output[4])

+    # Tests whether the stable softmax is stable on CPU, with and without XLA
+    def test_xla_stable_softmax(self):
+        large_penalty = -1e9
+        n_tokens = 10
+        batch_size = 8
+
+        def masked_softmax(x, boolean_mask):
+            numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
+            masked_x = x + numerical_mask
+            return stable_softmax(masked_x)
+
+        xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
+        xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
+        x = tf.random.normal((batch_size, n_tokens))
+
+        # Same outcome regardless of the boolean mask here
+        masked_tokens = random.randint(0, n_tokens)
+        boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
+
+        # We can randomly mask a random numerical input OUTSIDE XLA
+        numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
+        masked_x = x + numerical_mask
+        xla_out = xla_stable_softmax(masked_x)
+        out = stable_softmax(masked_x)
+        assert tf.experimental.numpy.allclose(xla_out, out)
+
+        # The stable softmax has the same output as the original softmax
+        unstable_out = tf.nn.softmax(masked_x)
+        assert tf.experimental.numpy.allclose(unstable_out, out)
+
+        # We can randomly mask a random numerical input INSIDE XLA
+        xla_out = xla_masked_softmax(x, boolean_mask)
+        out = masked_softmax(x, boolean_mask)
+        assert tf.experimental.numpy.allclose(xla_out, out)
+

 @require_tf
 @is_staging_test
...