"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "63caa370e6c618dbe7d3fd4cbf545cc32eca1a15"
Unverified commit 1e3c9dda, authored by Thien Tran, committed by GitHub

Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032)



* set encoder's PE as non-trainable

* freeze flax

* init sinusoids

* add test for non-trainable embed positions

* simplify TF encoder embed_pos

* revert tf

* clean up

* add sinusoidal init for jax

* make consistent sinusoidal function

* fix dtype

* add default dtype

* use numpy for sinusoids. fix jax

* add sinusoid init for TF

* fix

* use custom embedding

* use specialized init for each impl

* fix sinusoids init. add test for pytorch

* fix TF dtype

* simplify sinusoid init for flax and tf

* add tests for TF

* change default dtype to float32

* add sinusoid test for flax

* Update src/transformers/models/whisper/modeling_flax_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* move sinusoidal init to _init_weights

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent fc639143
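
Net effect of the change, before the file-by-file diff: a freshly initialised PyTorch Whisper encoder now carries fixed sinusoidal position embeddings that are excluded from gradient updates. A minimal illustrative sketch (not part of the commit; it assumes a transformers install that already contains this change and uses the default WhisperConfig):

import torch
from transformers import WhisperConfig, WhisperModel
from transformers.models.whisper.modeling_whisper import sinusoids

# The encoder's positional table is initialised to sinusoids() and frozen.
model = WhisperModel(WhisperConfig())
weight = model.get_encoder().embed_positions.weight

print(weight.requires_grad)                              # False -> excluded from training
print(torch.allclose(weight, sinusoids(*weight.shape)))  # True  -> sinusoidal values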
src/transformers/models/whisper/modeling_flax_whisper.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Flax whisper model."""
 
+import math
 import random
 from functools import partial
 from typing import Optional, Tuple
@@ -58,6 +59,19 @@ _CONFIG_FOR_DOC = "WhisperConfig"
 
 remat = nn_partitioning.remat
 
+
+def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
+    """Returns sinusoids for positional embedding"""
+    length, channels = shape
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(10000) / (channels // 2 - 1)
+    inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
+    scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
+    return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
+
+
 WHISPER_START_DOCSTRING = r"""
     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -649,7 +663,13 @@ class FlaxWhisperEncoder(nn.Module):
             dtype=self.dtype,
             gradient_checkpointing=self.gradient_checkpointing,
         )
-        self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
+        self.embed_positions = nn.Embed(
+            self.config.max_source_positions,
+            self.config.d_model,
+            dtype=self.dtype,
+            embedding_init=sinusoidal_embedding_init,
+        )
 
         self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -673,6 +693,8 @@
         hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
 
         embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
+        # freeze the sinusoidal embeddings by stopping the back-prop
+        embed_positions = jax.lax.stop_gradient(embed_positions)
         hidden_states = hidden_states + embed_positions
 
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
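
The Flax implementation freezes the table at use time, with jax.lax.stop_gradient, rather than by flagging a parameter as non-trainable. A minimal, self-contained sketch of that mechanism (a toy module for illustration only, not the Whisper encoder itself):

import jax
import jax.numpy as jnp
import flax.linen as nn

class FrozenPositions(nn.Module):
    # Toy layer: an Embed table whose output passes through stop_gradient,
    # mirroring how FlaxWhisperEncoder blocks back-prop into its sinusoids.
    length: int = 16
    channels: int = 8

    @nn.compact
    def __call__(self, hidden_states):
        positions = nn.Embed(self.length, self.channels)(jnp.arange(self.length))
        positions = jax.lax.stop_gradient(positions)  # no gradient reaches the table
        return (hidden_states + positions).sum()

module = FrozenPositions()
params = module.init(jax.random.PRNGKey(0), jnp.zeros((16, 8)))
grads = jax.grad(lambda p: module.apply(p, jnp.ones((16, 8))))(params)

# Every entry of the embedding's gradient is exactly zero.
print(jnp.all(grads["params"]["Embed_0"]["embedding"] == 0))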
src/transformers/models/whisper/modeling_tf_whisper.py

@@ -59,6 +59,19 @@ TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 LARGE_NEGATIVE = -1e8
 
 
+def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
+    """Returns sinusoids for positional embedding"""
+    length, channels = shape
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(10000) / (channels // 2 - 1)
+    inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
+    scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
+    return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
+
+
 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
     pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
@@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
 class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs):
+    def __init__(
+        self,
+        num_positions: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        embedding_initializer=None,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.num_positions = num_positions
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
+        self.embedding_initializer = tf.keras.initializers.get(embedding_initializer)
 
     def build(self, input_shape):
         self.weight = self.add_weight(
             name="weight",
             shape=[self.num_positions, self.embedding_dim],
+            initializer=self.embedding_initializer,
             trainable=True,
         )
         super().build(input_shape)
@@ -620,8 +642,12 @@ class TFWhisperEncoder(tf.keras.layers.Layer):
         self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
 
         self.embed_positions = TFWhisperPositionalEmbedding(
-            self.max_source_positions, self.embed_dim, name="embed_positions"
+            num_positions=self.max_source_positions,
+            embedding_dim=self.embed_dim,
+            embedding_initializer=sinusoidal_embedding_init,
+            name="embed_positions",
         )
+        self.embed_positions.trainable = False
 
         self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
         self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
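
On the TensorFlow side the change leans on two standard Keras mechanisms: any (shape, dtype) callable can be passed to add_weight as an initializer, and setting trainable = False on a layer removes its weights from trainable_weights. A rough standalone sketch of those two mechanisms (toy names, not the actual TFWhisperPositionalEmbedding):

import math
import tensorflow as tf

def toy_sinusoid_init(shape, dtype=tf.float32):
    # Same recipe as sinusoidal_embedding_init above, duplicated here so the
    # sketch runs on its own.
    length, channels = shape
    log_timescale_increment = math.log(10000) / (channels // 2 - 1)
    inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
    scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
    return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)

class ToyPositions(tf.keras.layers.Layer):
    def build(self, input_shape):
        # add_weight invokes the initializer with (shape, dtype).
        self.weight = self.add_weight(
            name="weight", shape=[16, 8], initializer=toy_sinusoid_init, trainable=True
        )
        super().build(input_shape)

layer = ToyPositions()
layer.build(None)
layer.trainable = False  # the layer-level flag wins over trainable=True on the weight

print(len(layer.trainable_weights))  # 0 -> nothing left to train
print(layer.weights[0][0])           # position 0 row: sin terms are 0.0, cos terms are 1.0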
src/transformers/models/whisper/modeling_whisper.py

@@ -55,6 +55,18 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
+    """Returns sinusoids for positional embedding"""
+    if channels % 2 != 0:
+        raise ValueError(
+            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+        )
+    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
+    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
+
+
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
@@ -668,6 +680,10 @@ class WhisperPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, WhisperEncoder):
+            with torch.no_grad():
+                embed_positions = module.embed_positions.weight
+                embed_positions.copy_(sinusoids(*embed_positions.shape))
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (WhisperDecoder, WhisperEncoder)):
@@ -835,6 +851,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
         self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
 
         self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+        self.embed_positions.requires_grad_(False)
 
         self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
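
A quick sanity check on the sinusoid recipe shared by all three implementations: position 0 yields sin terms of 0 and cos terms of 1, and channel 0 runs at the fastest timescale (frequency 1), so its column is simply sin(position). A small worked example using the PyTorch helper added above (again assuming a transformers version that includes this change):

import torch
from transformers.models.whisper.modeling_whisper import sinusoids

table = sinusoids(length=4, channels=6)  # columns 0-2 are sin, columns 3-5 are cos

print(table[0])  # tensor([0., 0., 0., 1., 1., 1.]) -> position 0
# Channel 0 uses timescale 1, so its sin column is just sin(position).
assert torch.allclose(table[:, 0], torch.sin(torch.arange(4, dtype=torch.float32)))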
tests/models/whisper/test_modeling_flax_whisper.py

@@ -46,6 +46,7 @@ if is_flax_available():
         WhisperProcessor,
     )
     from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
+    from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
 
 
 @require_flax
@@ -387,6 +388,19 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_encoder_sinusoidal_embed_positions(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            params = model.params
+            if model.base_model_prefix in params:
+                params = model.params[model.base_model_prefix]
+
+            embeds = params["encoder"]["embed_positions"]["embedding"]
+            sinusoids = sinusoidal_embedding_init(None, embeds.shape)
+            self.assertTrue(jax.numpy.allclose(embeds, sinusoids))
+
 
 @slow
 @require_flax
tests/models/whisper/test_modeling_tf_whisper.py

@@ -42,7 +42,11 @@ if is_tf_available():
     import tensorflow as tf
 
     from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
-    from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
+    from transformers.models.whisper.modeling_tf_whisper import (
+        TFWhisperDecoder,
+        TFWhisperEncoder,
+        sinusoidal_embedding_init,
+    )
 
 
 def prepare_whisper_inputs_dict(
@@ -297,6 +301,23 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_forward(*config_and_inputs)
 
+    def test_requires_grad_encoder_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            encoder = model.get_encoder()
+            self.assertFalse(encoder.embed_positions.trainable)
+
+    def test_encoder_sinusoidal_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.build()
+
+            embeds = model.get_encoder().embed_positions.get_weights()[0]
+            sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
+            self.assertTrue(np.allclose(embeds, sinusoids))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
tests/models/whisper/test_modeling_whisper.py

@@ -49,7 +49,7 @@ if is_torch_available():
         WhisperProcessor,
         set_seed,
     )
-    from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
+    from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
 
 if is_flax_available():
     import jax.numpy as jnp
@@ -351,6 +351,20 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         self.assertFalse(all(encoder_grads))
         self.assertTrue(all(decoder_grads))
 
+    def test_requires_grad_encoder_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            encoder = model.get_encoder()
+            self.assertFalse(encoder.embed_positions.weight.requires_grad)
+
+    def test_encoder_sinusoidal_embed_positions(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            embeds = model.get_encoder().embed_positions.weight
+            self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
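
Because the freeze is a default rather than a structural restriction, a user who does want to fine-tune the positional table can flip the flag back in their own code. A hedged example for the PyTorch model (the attribute names are the ones introduced in this diff; "openai/whisper-tiny" is just a convenient public checkpoint):

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# Opt back in to training the encoder's positional embeddings.
model.get_encoder().embed_positions.requires_grad_(True)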