chenpangpang / transformers · Commits

Commit fc5a38ac authored Dec 06, 2018 by Grégory Châtel

Adding the BertForMultipleChoiceClass.

parent c45d8ac5
Showing 2 changed files with 71 additions and 2 deletions (+71 -2):

pytorch_pretrained_bert/__init__.py   +2 -2
pytorch_pretrained_bert/modeling.py   +69 -0
pytorch_pretrained_bert/__init__.py
```diff
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
-                       BertForSequenceClassification, BertForTokenClassification,
-                       BertForQuestionAnswering)
+                       BertForSequenceClassification, BertForMultipleChoice,
+                       BertForTokenClassification, BertForQuestionAnswering)
 from .optimization import BertAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
```
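With these two changed lines, the new head is exported from the package root alongside the other task heads. A minimal sketch, assuming the package is installed under its usual name `pytorch_pretrained_bert`:

```python
# The multiple-choice head is now importable from the package root,
# next to the other task-specific heads.
from pytorch_pretrained_bert import BertConfig, BertForMultipleChoice
```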
pytorch_pretrained_bert/modeling.py
@@ -877,6 +877,75 @@ class BertForSequenceClassification(PreTrainedBertModel):

````python
        return logits


class BertForMultipleChoice(PreTrainedBertModel):
    """BERT model for multiple choice tasks.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_choices`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
            and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller
            than the max input sequence length in the current batch. It's the mask that we typically use
            for attention when a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1].

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_choices].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_choices = 2

    model = BertForMultipleChoice(config, num_choices)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_choices=2):
        super(BertForMultipleChoice, self).__init__(config)
        self.num_choices = num_choices
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Each choice receives a single score; choices compete via the softmax inside the loss.
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # Fold the choice axis into the batch axis: [batch, choices, seq] -> [batch * choices, seq].
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Regroup the per-sequence scores into one row per example: [batch, choices].
        reshaped_logits = logits.view(-1, self.num_choices)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            return loss
        else:
            return reshaped_logits


class BertForTokenClassification(PreTrainedBertModel):
    """BERT model for token-level classification.
    This module is composed of the BERT model with a linear layer on top of
    ...
````
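The head scores each choice independently: `forward` folds the choice axis into the batch axis before calling BERT, maps every pooled output to a single scalar with `nn.Linear(config.hidden_size, 1)`, and then regroups the scalars so that `CrossEntropyLoss` treats the `num_choices` scores of one example as competing classes. A standalone sketch of that shape bookkeeping, with hypothetical toy sizes and a random tensor standing in for the real BERT-plus-classifier output:

```python
import torch

# Hypothetical toy sizes, chosen only to make the shapes easy to follow.
batch_size, num_choices, seq_len = 2, 2, 3
input_ids = torch.arange(batch_size * num_choices * seq_len).view(
    batch_size, num_choices, seq_len)

# forward() flattens the choice axis into the batch axis, so BERT sees
# each (example, choice) pair as an independent sequence.
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
assert flat_input_ids.shape == (batch_size * num_choices, seq_len)

# Stand-in for self.classifier(pooled_output): one scalar score per sequence.
logits = torch.randn(batch_size * num_choices, 1)

# Regrouping by num_choices gives one row of competing scores per example,
# exactly the [batch_size, num_choices] input that CrossEntropyLoss expects.
reshaped_logits = logits.view(-1, num_choices)
assert reshaped_logits.shape == (batch_size, num_choices)

labels = torch.tensor([0, 1])  # index of the correct choice for each example
loss = torch.nn.CrossEntropyLoss()(reshaped_logits, labels)
```

Scoring every choice with the same single-output linear layer, rather than with a fixed `num_choices`-way classifier, is what lets the head rank any number of choices with a single set of weights.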