Unverified Commit 1e3c9dda authored by Thien Tran, committed by GitHub

Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032)



* set encoder's PE as non-trainable

* freeze flax

* init sinusoids

* add test for non-trainable embed positions

* simplify TF encoder embed_pos

* revert tf

* clean up

* add sinusoidal init for jax

* make consistent sinusoidal function

* fix dtype

* add default dtype

* use numpy for sinusoids. fix jax

* add sinusoid init for TF

* fix

* use custom embedding

* use specialized init for each impl

* fix sinusoids init. add test for pytorch

* fix TF dtype

* simplify sinusoid init for flax and tf

* add tests for TF

* change default dtype to float32

* add sinusoid test for flax

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* move sinusoidal init to _init_weights

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent fc639143
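
In short, the change pins the Whisper encoder's positional embeddings to the fixed sinusoids of the original paper and keeps them out of training in all three backends (PyTorch, TensorFlow, Flax). A minimal sketch of what this means for a downstream PyTorch user, assuming the public openai/whisper-tiny checkpoint:

    import torch
    from transformers import WhisperModel

    model = WhisperModel.from_pretrained("openai/whisper-tiny")

    # After this commit the encoder PE is frozen by default ...
    print(model.get_encoder().embed_positions.weight.requires_grad)  # False
    # ... while the decoder's learned positional embeddings stay trainable.
    print(model.get_decoder().embed_positions.weight.requires_grad)  # True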
src/transformers/models/whisper/modeling_flax_whisper.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Flax whisper model."""
 
+import math
 import random
 from functools import partial
 from typing import Optional, Tuple
@@ -58,6 +59,19 @@ _CONFIG_FOR_DOC = "WhisperConfig"
 remat = nn_partitioning.remat
 
 
+def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
+    """Returns sinusoids for positional embedding"""
+    length, channels = shape
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(10000) / (channels // 2 - 1)
+    inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
+    scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
+    return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
+
+
 WHISPER_START_DOCSTRING = r"""
     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -649,7 +663,13 @@ class FlaxWhisperEncoder(nn.Module):
             dtype=self.dtype,
             gradient_checkpointing=self.gradient_checkpointing,
         )
-        self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
+        self.embed_positions = nn.Embed(
+            self.config.max_source_positions,
+            self.config.d_model,
+            dtype=self.dtype,
+            embedding_init=sinusoidal_embedding_init,
+        )
 
         self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -673,6 +693,8 @@
         hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
 
         embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
+        # freeze the sinusoidal embeddings by stopping the back-prop
+        embed_positions = jax.lax.stop_gradient(embed_positions)
         hidden_states = hidden_states + embed_positions
 
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
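Two notes on the Flax hunks above. Flax initializers are invoked as init(key, shape, dtype), which is why sinusoidal_embedding_init accepts a PRNG key it never uses; the table is fully deterministic. And because a single entry in a Flax params tree cannot be marked non-trainable, the forward pass freezes it with jax.lax.stop_gradient instead. A quick sanity check (a sketch, using whisper-tiny's encoder shape of 1500 positions by 384 channels):

    import jax
    import jax.numpy as jnp
    from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init

    table = sinusoidal_embedding_init(jax.random.PRNGKey(0), (1500, 384))
    # Position 0 gives sin(0) = 0 in the first half of the channels and cos(0) = 1 in the second.
    assert jnp.allclose(table[0, :192], 0.0)
    assert jnp.allclose(table[0, 192:], 1.0)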
src/transformers/models/whisper/modeling_tf_whisper.py

@@ -59,6 +59,19 @@ TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 LARGE_NEGATIVE = -1e8
 
 
+def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
+    """Returns sinusoids for positional embedding"""
+    length, channels = shape
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(10000) / (channels // 2 - 1)
+    inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
+    scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
+    return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
+
+
 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
     pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
@@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
 class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs):
+    def __init__(
+        self,
+        num_positions: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        embedding_initializer=None,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
+        self.embedding_initializer = tf.keras.initializers.get(embedding_initializer)
 
     def build(self, input_shape):
         self.weight = self.add_weight(
             name="weight",
             shape=[self.num_positions, self.embedding_dim],
+            initializer=self.embedding_initializer,
             trainable=True,
         )
         super().build(input_shape)
@@ -620,8 +642,12 @@ class TFWhisperEncoder(tf.keras.layers.Layer):
         self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
 
         self.embed_positions = TFWhisperPositionalEmbedding(
-            self.max_source_positions, self.embed_dim, name="embed_positions"
+            num_positions=self.max_source_positions,
+            embedding_dim=self.embed_dim,
+            embedding_initializer=sinusoidal_embedding_init,
+            name="embed_positions",
         )
+        self.embed_positions.trainable = False
 
         self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
         self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
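On the TensorFlow side the freeze happens at the Keras level: tf.keras.initializers.get passes a plain callable through unchanged (and maps None to None, so add_weight falls back to its default initializer), and setting a layer's trainable attribute to False removes its weights from trainable_variables, so optimizers and tape gradients skip them. A generic Keras illustration of that mechanism, not specific to Whisper:

    import tensorflow as tf

    emb = tf.keras.layers.Embedding(4, 2)
    emb.build((None,))
    emb.trainable = False

    # Frozen weights move from trainable_variables to non_trainable_variables.
    print(len(emb.trainable_weights), len(emb.non_trainable_weights))  # 0 1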
src/transformers/models/whisper/modeling_whisper.py

@@ -55,6 +55,18 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
+    """Returns sinusoids for positional embedding"""
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
+    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
+
+
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
@@ -668,6 +680,10 @@ class WhisperPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, WhisperEncoder):
+            with torch.no_grad():
+                embed_positions = module.embed_positions.weight
+                embed_positions.copy_(sinusoids(*embed_positions.shape))
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (WhisperDecoder, WhisperEncoder)):
@@ -835,6 +851,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
         self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
 
         self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+        self.embed_positions.requires_grad_(False)
 
         self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
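For reference, sinusoids implements the half-sin/half-cos channel layout used by OpenAI's original Whisper code (not the interleaved layout of the original Transformer paper). With C channels and position p, the table it builds is:

    \mathrm{PE}[p,\, i]       = \sin\bigl(p \cdot e^{-i \,\ln(10000)/(C/2 - 1)}\bigr), \qquad 0 \le i < C/2
    \mathrm{PE}[p,\, C/2 + i] = \cos\bigl(p \cdot e^{-i \,\ln(10000)/(C/2 - 1)}\bigr)

Hooking this into _init_weights means a freshly initialized WhisperEncoder already carries the exact table that pretrained checkpoints store, so the from_pretrained and from-scratch paths agree.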
tests/models/whisper/test_modeling_flax_whisper.py

@@ -46,6 +46,7 @@ if is_flax_available():
         WhisperProcessor,
     )
     from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
+    from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
 
 
 @require_flax
@@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
             max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
             self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_encoder_sinusoidal_embed_positions(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            params = model.params
+            if model.base_model_prefix in params:
+                params = model.params[model.base_model_prefix]
+
+            embeds = params["encoder"]["embed_positions"]["embedding"]
+            sinusoids = sinusoidal_embedding_init(None, embeds.shape)
+            self.assertTrue(jax.numpy.allclose(embeds, sinusoids))
+
 
 @slow
 @require_flax
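The params indirection in the Flax test exists because head models nest the base model's weights one level deeper; roughly (illustrative shapes, not exact):

    # FlaxWhisperModel.params                    -> {"encoder": {...}, "decoder": {...}}
    # FlaxWhisperForConditionalGeneration.params -> {"model": {"encoder": {...}, "decoder": {...}}, ...}
    # model.base_model_prefix == "model" selects the nested subtree when present.

Passing None as the PRNG key is safe here precisely because the initializer ignores it.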
tests/models/whisper/test_modeling_tf_whisper.py

@@ -42,7 +42,11 @@ if is_tf_available():
     import tensorflow as tf
 
     from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
-    from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
+    from transformers.models.whisper.modeling_tf_whisper import (
+        TFWhisperDecoder,
+        TFWhisperEncoder,
+        sinusoidal_embedding_init,
+    )
 
 
 def prepare_whisper_inputs_dict(
@@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_forward(*config_and_inputs)
 
+    def test_requires_grad_encoder_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            encoder = model.get_encoder()
+            self.assertFalse(encoder.embed_positions.trainable)
+
+    def test_encoder_sinusoidal_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.build()
+
+            embeds = model.get_encoder().embed_positions.get_weights()[0]
+            sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
+            self.assertTrue(np.allclose(embeds, sinusoids))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
tests/models/whisper/test_modeling_whisper.py

@@ -49,7 +49,7 @@
         WhisperProcessor,
         set_seed,
     )
-    from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
+    from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
 
 if is_flax_available():
     import jax.numpy as jnp
@@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         self.assertFalse(all(encoder_grads))
         self.assertTrue(all(decoder_grads))
 
+    def test_requires_grad_encoder_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            encoder = model.get_encoder()
+            self.assertFalse(encoder.embed_positions.weight.requires_grad)
+
+    def test_encoder_sinusoidal_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            embeds = model.get_encoder().embed_positions.weight
+            self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
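A practical consequence for fine-tuning scripts: the frozen table is skipped by any optimizer setup that filters on requires_grad, with no change needed on the user side. A sketch, reusing the model from the first snippet above:

    import torch

    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)  # embed_positions.weight is excluded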