Unverified commit 83fdd252 authored by Patrick von Platen, committed by GitHub

[Seq2Seq Templates] Correct some TF-serving errors and add gradient checkpointing to PT by default. (#9334)

* correct tests

* correct shape and get_tf_activation

* more correction tf

* add gradient checkpointing to templates

* correct typo
parent 8e74eca7
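
Two themes run through the diff below: the TF templates replace eager-only constructs (.shape[...], ACT2FN[...], tf.constant) with serving-safe equivalents (shape_list, get_tf_activation, tf.convert_to_tensor), and the PyTorch templates gain optional gradient checkpointing. As a rough illustration of the three TF patterns, here is a minimal sketch built around a hypothetical ToyLayer that is not part of the templates:

    import tensorflow as tf
    from transformers.activations_tf import get_tf_activation
    from transformers.modeling_tf_utils import shape_list

    class ToyLayer(tf.keras.layers.Layer):
        """Hypothetical layer; only the three highlighted patterns matter."""

        def __init__(self, activation="gelu"):
            super().__init__()
            # get_tf_activation replaces direct ACT2FN dictionary lookups.
            self.activation_fn = get_tf_activation(activation)

        def call(self, hidden_states):
            # shape_list replaces .shape[...]: it falls back to tf.shape for
            # dimensions that are unknown (None) in a traced/serving graph.
            batch_size, seq_len = shape_list(hidden_states)[:2]
            return self.activation_fn(hidden_states)

    # tf.convert_to_tensor replaces tf.constant when building dummy inputs.
    dummy_input = tf.convert_to_tensor([[[0.1, 0.2], [0.3, 0.4]]])
    print(ToyLayer()(dummy_input))
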
@@ -74,6 +74,8 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
{% else -%}
vocab_size (:obj:`int`, `optional`, defaults to 50265):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
@@ -172,6 +174,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
init_std=0.02,
decoder_start_token_id=2,
classifier_dropout=0.0,
gradient_checkpointing=False,
{% endif -%}
pad_token_id=1,
bos_token_id=0,
@@ -222,6 +225,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
{% endif -%}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
......
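
For reference, a minimal sketch of how the new flag is meant to be consumed once a model has been generated from this template. MyGeneratedConfig below is a hypothetical placeholder for the generated {{cookiecutter.camelcase_modelname}}Config, and the getattr guard mirrors the check added to the PyTorch encoder and decoder further down:

    from transformers import PretrainedConfig

    class MyGeneratedConfig(PretrainedConfig):
        """Placeholder for the generated config; the real class hard-codes the same default."""

        def __init__(self, gradient_checkpointing=False, use_cache=True, **kwargs):
            super().__init__(**kwargs)
            self.gradient_checkpointing = gradient_checkpointing
            self.use_cache = use_cache

    # Trade extra forward compute for lower activation memory during training;
    # caching is incompatible with a checkpointed decoder, so it is turned off.
    config = MyGeneratedConfig(gradient_checkpointing=True, use_cache=False)

    if getattr(config, "gradient_checkpointing", False):
        print("each layer call will be wrapped in torch.utils.checkpoint.checkpoint")
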
@@ -20,6 +20,7 @@
import tensorflow as tf
from transformers.modeling_tf_outputs import TFCausalLMOutput
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -37,14 +38,14 @@ from ...modeling_tf_outputs import (
TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
TFCausalLanguageModelingLoss,
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
@@ -503,7 +504,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
@@ -1109,7 +1110,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
@@ -1404,7 +1405,7 @@ from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...activations_tf import get_tf_activation
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -1640,7 +1641,7 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
)
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
self.activation_fn = get_tf_activation(config.activation_function)
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
@@ -1689,7 +1690,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
is_decoder=True,
)
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
self.activation_fn = get_tf_activation(config.activation_function)
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
@@ -1782,8 +1783,8 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
@property
def dummy_inputs(self):
pad_token = 1
input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
decoder_input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
dummy_inputs = {
"decoder_input_ids": decoder_input_ids,
"attention_mask": tf.math.not_equal(input_ids, pad_token),
@@ -2134,7 +2135,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = (
inputs["past_key_values"][0][0].shape[2] if inputs["past_key_values"] is not None else 0
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
)
# embed positions
@@ -2390,7 +2391,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
# {{cookiecutter.uppercase_modelname}} is a special case where the bias has two dimensions
# and not named just `bias`
if new_num_tokens is not None:
num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
num_tokens_to_copy = min(shape_list(self.final_logits_bias)[0], new_num_tokens)
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight(
......
@@ -25,6 +25,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -48,7 +49,6 @@ from ...modeling_utils import (
prune_linear_layer,
)
from ...utils import logging
from ...activations import ACT2FN
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -1809,7 +1809,13 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
return hidden_states, attn_weights
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
@@ -1846,7 +1852,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[torch.Tensor] = False,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
):
"""
Args:
@@ -1907,12 +1914,15 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
return (
hidden_states,
self_attn_weights,
present_key_value,
cross_attn_weights,
)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
@@ -2178,12 +2188,28 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
)
else:
hidden_states, attn = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (attn,)
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
@@ -2355,21 +2381,46 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None
hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer(
if getattr(self.config, "gradient_checkpointing", False):
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
)
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (present_key_value,)
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_self_attn,)
all_cross_attentions += (layer_cross_attn,)
all_self_attns += (layer_outputs[1],)
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
......
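
The checkpointing pattern added to the encoder and decoder loops above can be exercised in isolation. Below is a self-contained sketch, with a toy module standing in for the generated {{cookiecutter.camelcase_modelname}}EncoderLayer; the create_custom_forward wrapper closes over the boolean flag so that only tensors are passed through the checkpoint call, mirroring the template code:

    import torch
    import torch.utils.checkpoint
    from torch import nn

    class ToyEncoderLayer(nn.Module):
        """Toy stand-in for the generated encoder layer."""

        def __init__(self, dim=16):
            super().__init__()
            self.fc = nn.Linear(dim, dim)

        def forward(self, hidden_states, attention_mask, output_attentions=False):
            hidden_states = torch.relu(self.fc(hidden_states)) * attention_mask.unsqueeze(-1)
            attn_weights = None  # the toy layer has no real attention
            return (hidden_states, attn_weights) if output_attentions else (hidden_states,)

    def create_custom_forward(module, output_attentions=False):
        # Close over the flag so the checkpointed call only receives tensors.
        def custom_forward(*inputs):
            return module(*inputs, output_attentions)
        return custom_forward

    layer = ToyEncoderLayer()
    hidden_states = torch.randn(2, 4, 16, requires_grad=True)
    attention_mask = torch.ones(2, 4)

    layer_outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(layer), hidden_states, attention_mask
    )
    # Activations are recomputed here instead of being stored during the forward pass.
    layer_outputs[0].sum().backward()

Because the forward pass is recomputed during backward, cached past key/values cannot be reused, which is why the decoder above raises a ValueError when gradient checkpointing is combined with use_cache=True.
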
@@ -532,7 +532,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCa
expected_slice = tf.Tensor(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
)
self.assertTrue(tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE))
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
def test_inference_with_head(self):
model = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
@@ -547,7 +547,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCa
expected_slice = tf.Tensor(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
)
self.assertTrue(tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE))
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
def test_seq_to_seq_generation(self):
hf = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
......
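
The change to the two integration tests above is needed because tf.debugging.assert_near is an assertion rather than a predicate: it raises InvalidArgumentError on a mismatch and returns None in eager mode, so wrapping it in self.assertTrue would fail even when the tensors agree. A minimal sketch with toy values:

    import tensorflow as tf

    expected_slice = tf.constant([[0.7144, 0.8143, -1.2813]])
    output = expected_slice + 1e-5  # well within tolerance

    # Passes silently; on a real mismatch it raises tf.errors.InvalidArgumentError.
    tf.debugging.assert_near(output, expected_slice, atol=1e-4)
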
@@ -683,23 +683,6 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, Generation
def test_config(self):
self.config_tester.run_common_tests()
def test_initialization_more(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
model = {{cookiecutter.camelcase_modelname}}Model(config)
model.to(torch_device)
model.eval()
# test init
self.assertTrue((model.encoder.embed_tokens.weight == model.shared.weight).all().item())
def _check_var(module):
"""Check that we initialized various parameters from N(0, config.init_std)."""
self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)
_check_var(model.encoder.embed_tokens)
_check_var(model.encoder.layers[0].self_attn.k_proj)
_check_var(model.encoder.layers[0].fc1)
_check_var(model.encoder.embed_positions)
def test_save_load_strict(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
......