"tests/rembert/test_modeling_rembert.py" did not exist on "ba8b1f4754257e140ddabbe04a7f3e493e33802d"
Unverified commit 0c1c42c1 authored by Philip May, committed by GitHub

add `classifier_dropout` to classification heads (#12794)



* add classifier_dropout to Electra

* no type annotations yet
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add classifier_dropout to Electra

* add classifier_dropout to Electra ForTokenClass.

* add classifier_dropout to bert

* add classifier_dropout to roberta

* add classifier_dropout to big_bird

* add classifier_dropout to mobilebert

* empty commit to trigger CI

* add classifier_dropout to reformer

* add classifier_dropout to ConvBERT

* add classifier_dropout to Albert

* add classifier_dropout to Albert
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 9ff672fc
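
Every diff below applies the same fallback: prefer the new classifier_dropout config attribute when it is set, otherwise keep the pre-existing hidden_dropout_prob, so unmodified configs behave exactly as before. A minimal, self-contained sketch of the pattern (DummyConfig is a hypothetical stand-in, not a transformers class):

# Sketch of the fallback pattern this PR applies to every classification head.
# DummyConfig is illustrative only.
class DummyConfig:
    def __init__(self, hidden_dropout_prob=0.1, classifier_dropout=None):
        self.hidden_dropout_prob = hidden_dropout_prob
        self.classifier_dropout = classifier_dropout  # None means "fall back"

def resolve_classifier_dropout(config):
    # Head-specific rate wins; otherwise reuse the encoder-wide dropout rate.
    return (
        config.classifier_dropout
        if config.classifier_dropout is not None
        else config.hidden_dropout_prob
    )

assert resolve_classifier_dropout(DummyConfig()) == 0.1  # old behavior preserved
assert resolve_classifier_dropout(DummyConfig(classifier_dropout=0.3)) == 0.3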
@@ -2474,7 +2474,10 @@ class ReformerClassificationHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

     def forward(self, hidden_states, **kwargs):
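Note that the added lines test `is not None` rather than truthiness: an explicit classifier_dropout=0.0 is a legitimate override that disables classifier dropout, which an `or`-based fallback would silently discard. A quick demonstration:

# 0.0 is falsy, so `or` would re-enable dropout; the None check keeps it off.
classifier_dropout = 0.0
with_or = classifier_dropout or 0.1
with_none_check = classifier_dropout if classifier_dropout is not None else 0.1
assert with_or == 0.1 and with_none_check == 0.0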
@@ -498,7 +498,12 @@ class FlaxRobertaClassificationHead(nn.Module):
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
         )
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        classifier_dropout = (
+            self.config.classifier_dropout
+            if self.config.classifier_dropout is not None
+            else self.config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(rate=classifier_dropout)
         self.out_proj = nn.Dense(
             self.config.num_labels,
             dtype=self.dtype,
@@ -877,7 +882,12 @@ class FlaxRobertaForTokenClassificationModule(nn.Module):
     def setup(self):
         self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        classifier_dropout = (
+            self.config.classifier_dropout
+            if self.config.classifier_dropout is not None
+            else self.config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(rate=classifier_dropout)
         self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

     def __call__(
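In the Flax modules the rate has to be resolved inside setup() because flax.linen.Dropout fixes its rate at construction time; whether dropout actually fires is decided per call through the deterministic flag. A minimal analogue of the changed modules (a sketch, not the transformers code):

from typing import Optional

import flax.linen as nn
import jax

class SketchHead(nn.Module):
    # Field names mirror the config attributes used in the diff above.
    hidden_dropout_prob: float = 0.1
    classifier_dropout: Optional[float] = None

    def setup(self):
        rate = (
            self.classifier_dropout
            if self.classifier_dropout is not None
            else self.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=rate)

    def __call__(self, x, deterministic: bool = True):
        return self.dropout(x, deterministic=deterministic)

x = jax.numpy.ones((2, 8))
head = SketchHead(classifier_dropout=0.3)
params = head.init(jax.random.PRNGKey(0), x)  # deterministic, so no dropout rng needed
out = head.apply(params, x, deterministic=True)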
@@ -1346,7 +1346,10 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
         self.num_labels = config.num_labels

         self.roberta = RobertaModel(config, add_pooling_layer=False)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

         self.init_weights()
@@ -1427,7 +1430,10 @@ class RobertaClassificationHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

     def forward(self, features, **kwargs):
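Assuming the config-side counterpart of this PR (a classifier_dropout argument on the configs that defaults to None), the PyTorch change can be exercised like this; the classifier.dropout attribute path follows the diff above:

# Usage sketch: assumes RobertaConfig accepts classifier_dropout per this PR.
from transformers import RobertaConfig, RobertaForSequenceClassification

config = RobertaConfig(classifier_dropout=0.2, num_labels=3)
model = RobertaForSequenceClassification(config)
assert model.classifier.dropout.p == 0.2  # head picks up the dedicated rate

default_model = RobertaForSequenceClassification(RobertaConfig(num_labels=3))
# Unset classifier_dropout falls back to hidden_dropout_prob (0.1 by default):
assert default_model.classifier.dropout.p == 0.1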
@@ -933,7 +933,10 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
             activation="tanh",
             name="dense",
         )
-        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = tf.keras.layers.Dropout(classifier_dropout)
         self.out_proj = tf.keras.layers.Dense(
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
         )
@@ -1206,7 +1209,10 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
         self.num_labels = config.num_labels

         self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
-        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = tf.keras.layers.Dropout(classifier_dropout)
         self.classifier = tf.keras.layers.Dense(
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )
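The TF heads resolve the rate identically; the only cosmetic difference is that tf.keras.layers.Dropout exposes its rate as .rate rather than PyTorch's .p. A sketch under the same config assumption as above:

# TF-side usage sketch; assumes the same classifier_dropout config argument.
from transformers import RobertaConfig, TFRobertaForSequenceClassification

config = RobertaConfig(classifier_dropout=0.25, num_labels=3)
tf_model = TFRobertaForSequenceClassification(config)
assert tf_model.classifier.dropout.rate == 0.25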