Commit 80f995a1 authored by thomwolf's avatar thomwolf
Browse files

revert BertForMultipleChoice linear classifier

parent 38ba7b43
...@@ -1034,7 +1034,7 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1034,7 +1034,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.num_choices = num_choices self.num_choices = num_choices
self.bert = BertModel(config) self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_choices) self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights) self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment