Commit a5381495 authored by Peter Izsak's avatar Peter Izsak Committed by Lysandre Debut
Browse files

Added classifier dropout rate in ALBERT

parent 83446a88
...@@ -76,6 +76,8 @@ class AlbertConfig(PretrainedConfig):
        The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
    layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
        The epsilon used by the layer normalization layers.
    classifier_dropout_prob (:obj:`float`, optional, defaults to 0.1):
        The dropout ratio for attached classifiers.
    Example::
...@@ -121,6 +123,7 @@ class AlbertConfig(PretrainedConfig):
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout_prob=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
...@@ -140,3 +143,4 @@ class AlbertConfig(PretrainedConfig):
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout_prob = classifier_dropout_prob
...@@ -698,7 +698,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
        self.num_labels = config.num_labels
        self.albert = AlbertModel(config)
-       self.dropout = nn.Dropout(config.hidden_dropout_prob)
+       self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        self.init_weights()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment