"examples/run_bert_classifier.py" did not exist on "d0673c7dbd4945118946e6f8e8bcb53cecfc7588"
Commit fc5a38ac authored by Grégory Châtel's avatar Grégory Châtel
Browse files

Adding the BertForMultipleChoiceClass.

parent c45d8ac5
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForTokenClassification,
BertForQuestionAnswering)
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering)
from .optimization import BertAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
......@@ -877,6 +877,75 @@ class BertForSequenceClassification(PreTrainedBertModel):
return logits
class BertForMultipleChoice(PreTrainedBertModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
return loss
else:
return reshaped_logits
class BertForTokenClassification(PreTrainedBertModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment