Unverified commit 5b369dc5, authored by Funtowicz Morgan and committed by GitHub

Remove assertion over possible activation functions in DistilBERT (#16066)

* Remove assertion over possible activation functions

* Same for TF and Flax
parent f5741bcd
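The change swaps the hard-coded `["relu", "gelu"]` checks for the library's activation lookup. As a rough sketch of what that lookup does (assuming `torch` and `transformers` are installed; shapes are arbitrary):

```python
import torch
from transformers.activations import get_activation

# get_activation maps a config string to the corresponding callable,
# so the model code no longer needs to assert on a fixed name list.
act = get_activation("gelu")
x = torch.randn(2, 4)
print(act(x).shape)  # torch.Size([2, 4])

# An unknown name fails inside get_activation (KeyError) instead of
# tripping an assertion in the DistilBERT modules.
```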
@@ -26,7 +26,7 @@ from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import gelu
+from ...activations import get_activation
 from ...deepspeed import is_deepspeed_zero3_enabled
 from ...file_utils import (
     add_code_sample_docstrings,
@@ -231,8 +231,7 @@ class FFN(nn.Module):
         self.seq_len_dim = 1
         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
-        assert config.activation in ["relu", "gelu"], f"activation ({config.activation}) must be in ['relu', 'gelu']"
-        self.activation = gelu if config.activation == "gelu" else nn.ReLU()
+        self.activation = get_activation(config.activation)
 
     def forward(self, input):
         return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
@@ -564,6 +563,8 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        self.activation = get_activation(config.activation)
+
         self.distilbert = DistilBertModel(config)
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
@@ -637,7 +638,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         )
         hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
-        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
+        prediction_logits = self.activation(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)
...
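With the PyTorch changes above, both `FFN` and the masked-LM head follow `config.activation`. A hedged usage sketch (randomly initialized model; `activation="relu"` is only an illustration, any name `get_activation` accepts should work):

```python
from transformers import DistilBertConfig, DistilBertForMaskedLM

# The configured activation now flows through get_activation in both
# the FFN blocks and the masked-LM prediction head.
config = DistilBertConfig(activation="relu")  # default is "gelu"
model = DistilBertForMaskedLM(config)
print(model.activation)  # resolved callable, no longer hard-coded gelu
```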
@@ -261,10 +261,7 @@ class FlaxFFN(nn.Module):
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
         )
-        assert self.config.activation in [
-            "relu",
-            "gelu",
-        ], f"activation ({self.config.activation}) must be in ['relu', 'gelu']"
         self.activation = ACT2FN[self.config.activation]
 
     def __call__(self, hidden_states, deterministic: bool = True):
@@ -576,7 +573,7 @@ class FlaxDistilBertForMaskedLMModule(nn.Module):
         )
         hidden_states = dlbrt_output[0]
         prediction_logits = self.vocab_transform(hidden_states)
-        prediction_logits = ACT2FN["gelu"](prediction_logits)
+        prediction_logits = ACT2FN[self.config.activation](prediction_logits)
         prediction_logits = self.vocab_layer_norm(prediction_logits)
 
         if self.config.tie_word_embeddings:
...
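On the Flax side, `FlaxFFN` already resolved its activation through `ACT2FN`; only the assertion and the hard-coded `"gelu"` in the MLM head go away. A minimal sketch of the lookup, assuming `jax`/`flax` are installed and that `ACT2FN` here is the mapping imported in this modeling file:

```python
import jax.numpy as jnp
from transformers.modeling_flax_utils import ACT2FN

# ACT2FN maps activation names to JAX-compatible callables.
act = ACT2FN["gelu"]
x = jnp.ones((2, 4))
print(act(x).shape)  # (2, 4)
```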
@@ -218,7 +218,6 @@ class TFFFN(tf.keras.layers.Layer):
         self.lin2 = tf.keras.layers.Dense(
             config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2"
         )
-        assert config.activation in ["relu", "gelu"], f"activation ({config.activation}) must be in ['relu', 'gelu']"
         self.activation = get_tf_activation(config.activation)
 
     def call(self, input, training=False):
@@ -642,7 +641,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
         self.vocab_transform = tf.keras.layers.Dense(
             config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform"
         )
-        self.act = get_tf_activation("gelu")
+        self.act = get_tf_activation(config.activation)
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
...
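The TensorFlow head mirrors the same idea, resolving `config.activation` via `get_tf_activation` rather than always using `"gelu"`. A small sketch, assuming TensorFlow is installed:

```python
import tensorflow as tf
from transformers.activations_tf import get_tf_activation

# get_tf_activation resolves the config string to a Keras-compatible callable.
act = get_tf_activation("gelu")
x = tf.ones((2, 4))
print(act(x).shape)  # (2, 4)
```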