Unverified Commit 5b369dc5 authored by Funtowicz Morgan, committed by GitHub

Remove assertion over possible activation functions in DistilBERT (#16066)

* Remove assertion over possible activation functions

* Same for TF and Flax
parent f5741bcd
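
Context for the change: `get_activation` and the framework-specific `ACT2FN` / `get_tf_activation` lookups already raise a `KeyError` for unknown names, so the hard-coded `relu`/`gelu` assertion mainly blocked otherwise supported activations. A minimal sketch of that lookup behaviour, assuming only the public `transformers.activations` API (illustration, not part of this diff):

```python
from transformers.activations import get_activation

# Any name registered in ACT2FN is accepted, e.g. via DistilBertConfig(activation="silu").
silu = get_activation("silu")
print(silu)

# Unknown names still fail loudly, so dropping the assert loses no safety.
try:
    get_activation("not_a_real_activation")
except KeyError as err:
    print(err)
```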
@@ -26,7 +26,7 @@ from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import gelu
+from ...activations import get_activation
 from ...deepspeed import is_deepspeed_zero3_enabled
 from ...file_utils import (
     add_code_sample_docstrings,
@@ -231,8 +231,7 @@ class FFN(nn.Module):
         self.seq_len_dim = 1
         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
-        assert config.activation in ["relu", "gelu"], f"activation ({config.activation}) must be in ['relu', 'gelu']"
-        self.activation = gelu if config.activation == "gelu" else nn.ReLU()
+        self.activation = get_activation(config.activation)
 
     def forward(self, input):
         return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
@@ -564,6 +563,8 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        self.activation = get_activation(config.activation)
+
         self.distilbert = DistilBertModel(config)
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
@@ -637,7 +638,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         )
         hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
-        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
+        prediction_logits = self.activation(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)
 
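
For the default configuration (`activation="gelu"`) the PyTorch changes above should be behaviour preserving, since `get_activation("gelu")` resolves to the same GELU that the removed `from ...activations import gelu` import exposed. A quick hedged check outside the model code, assuming `torch` and `transformers` are available:

```python
import torch
from transformers.activations import gelu, get_activation

x = torch.randn(2, 4)
# Both paths should compute identical GELU values for the default config.
assert torch.allclose(gelu(x), get_activation("gelu")(x))
```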
@@ -261,10 +261,7 @@ class FlaxFFN(nn.Module):
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
         )
-        assert self.config.activation in [
-            "relu",
-            "gelu",
-        ], f"activation ({self.config.activation}) must be in ['relu', 'gelu']"
+
         self.activation = ACT2FN[self.config.activation]
 
     def __call__(self, hidden_states, deterministic: bool = True):
@@ -576,7 +573,7 @@ class FlaxDistilBertForMaskedLMModule(nn.Module):
         )
         hidden_states = dlbrt_output[0]
         prediction_logits = self.vocab_transform(hidden_states)
-        prediction_logits = ACT2FN["gelu"](prediction_logits)
+        prediction_logits = ACT2FN[self.config.activation](prediction_logits)
         prediction_logits = self.vocab_layer_norm(prediction_logits)
 
         if self.config.tie_word_embeddings:
@@ -218,7 +218,6 @@ class TFFFN(tf.keras.layers.Layer):
         self.lin2 = tf.keras.layers.Dense(
             config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2"
         )
-        assert config.activation in ["relu", "gelu"], f"activation ({config.activation}) must be in ['relu', 'gelu']"
         self.activation = get_tf_activation(config.activation)
 
     def call(self, input, training=False):
@@ -642,7 +641,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
         self.vocab_transform = tf.keras.layers.Dense(
             config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform"
         )
-        self.act = get_tf_activation("gelu")
+        self.act = get_tf_activation(config.activation)
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
 
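
With all three frameworks now reading the activation from the config, including the masked-LM heads that previously hard-coded GELU, a model can be built with any activation the mappings know about. A hedged usage sketch; `"silu"` is an illustrative choice, not something introduced by this commit:

```python
from transformers import DistilBertConfig, DistilBertModel

# The FFN and the MLM head resolve this name through get_activation / ACT2FN /
# get_tf_activation instead of being limited to "relu" and "gelu".
config = DistilBertConfig(activation="silu")
model = DistilBertModel(config)

# Inspect the resolved activation on the first transformer block's FFN.
print(model.transformer.layer[0].ffn.activation)
```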