"vscode:/vscode.git/clone" did not exist on "3f290e6c8403c6a2cf80dce068869793bde49540"
Commit 38ba7b43 authored by David Pollack's avatar David Pollack
Browse files

fixed BertForMultipleChoice model init and forward pass

parent 61674333
......@@ -1034,13 +1034,13 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.num_choices = num_choices
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.classifier = nn.Linear(config.hidden_size, num_choices)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment