chenpangpang / transformers / Commits

Commit f9f8a531 (unverified)
Authored Jun 15, 2020 by Sylvain Gugger, committed via GitHub on Jun 15, 2020

Add DistilBertForMultipleChoice (#5032)
* Add `DistilBertForMultipleChoice`
Parent: 36434220
Showing 6 changed files with 144 additions and 1 deletion
docs/source/model_doc/distilbert.rst             +7    -0
src/transformers/__init__.py                     +1    -0
src/transformers/configuration_distilbert.py     +1    -1
src/transformers/modeling_auto.py                +2    -0
src/transformers/modeling_distilbert.py          +108  -0
tests/test_modeling_distilbert.py                +25   -0
docs/source/model_doc/distilbert.rst

@@ -75,6 +75,13 @@ DistilBertForSequenceClassification
    :members:


DistilBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertForMultipleChoice
    :members:


DistilBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
src/transformers/__init__.py

@@ -271,6 +271,7 @@ if is_torch_available():
        DistilBertPreTrainedModel,
        DistilBertForMaskedLM,
        DistilBertModel,
        DistilBertForMultipleChoice,
        DistilBertForSequenceClassification,
        DistilBertForQuestionAnswering,
        DistilBertForTokenClassification,
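With the new symbol exported from the package root, it can be imported directly alongside the other DistilBERT heads. A minimal sketch, assuming a transformers install that contains this commit:

    # Minimal import check (assumes a transformers build that includes this commit).
    # The export sits behind the is_torch_available() guard, so PyTorch must be installed.
    from transformers import DistilBertForMultipleChoice, is_torch_available

    assert is_torch_available()
    print(DistilBertForMultipleChoice.__name__)  # DistilBertForMultipleChoice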
src/transformers/configuration_distilbert.py

@@ -75,7 +75,7 @@ class DistilBertConfig(PretrainedConfig):
            The dropout probabilities used in the question answering model
            :class:`~transformers.DistilBertForQuestionAnswering`.
        seq_classif_dropout (:obj:`float`, optional, defaults to 0.2):
-           The dropout probabilities used in the sequence classification model
+           The dropout probabilities used in the sequence classification and the multiple choice model
            :class:`~transformers.DistilBertForSequenceClassification`.

    Example::
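Because the multiple choice head reuses `seq_classif_dropout`, changing that one config field now affects the dropout in front of both classifiers. A small sketch; the 0.3 value is arbitrary and only for illustration:

    from transformers import DistilBertConfig, DistilBertForMultipleChoice

    # seq_classif_dropout feeds both DistilBertForSequenceClassification and the new
    # DistilBertForMultipleChoice head (see self.dropout in the modeling code below).
    config = DistilBertConfig(seq_classif_dropout=0.3)
    model = DistilBertForMultipleChoice(config)
    print(model.dropout)  # Dropout(p=0.3, inplace=False)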
src/transformers/modeling_auto.py

@@ -78,6 +78,7 @@ from .modeling_camembert import (
from .modeling_ctrl import CTRLLMHeadModel, CTRLModel
from .modeling_distilbert import (
    DistilBertForMaskedLM,
    DistilBertForMultipleChoice,
    DistilBertForQuestionAnswering,
    DistilBertForSequenceClassification,
    DistilBertForTokenClassification,

...

@@ -314,6 +315,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
        (LongformerConfig, LongformerForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (DistilBertConfig, DistilBertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
        (AlbertConfig, AlbertForMultipleChoice),
    ]
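With `DistilBertConfig` registered in `MODEL_FOR_MULTIPLE_CHOICE_MAPPING`, the auto classes can resolve a DistilBERT checkpoint to the new head. A rough sketch, assuming `AutoModelForMultipleChoice` is exposed in this version of the library; the classification head is freshly initialized and still needs fine-tuning:

    from transformers import AutoModelForMultipleChoice

    # The mapping above dispatches DistilBertConfig to DistilBertForMultipleChoice,
    # so loading a DistilBERT checkpoint through the auto class returns the new head.
    model = AutoModelForMultipleChoice.from_pretrained("distilbert-base-cased")
    print(type(model).__name__)  # DistilBertForMultipleChoice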
src/transformers/modeling_distilbert.py

@@ -864,3 +864,111 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings(
    """DistilBert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, 1)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
                Classification loss.
            classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
                `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
                Classification scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
            import torch

            tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
            model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')

            prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            choice0 = "It is eaten with a fork and a knife."
            choice1 = "It is eaten while held in the hand."
            labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

            encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
            outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1

            # the linear classifier still needs to be trained
            loss, logits = outputs[:2]
        """
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)
        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)
        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)

        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)

        outputs = (reshaped_logits,) + outputs[1:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
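To make the shape handling in `forward` concrete, here is a small sketch that drives a tiny, randomly initialized model with dummy token ids; all sizes are arbitrary and chosen only to illustrate the `(batch_size, num_choices, sequence_length)` input contract and the `(batch_size, num_choices)` logits it produces:

    import torch
    from transformers import DistilBertConfig, DistilBertForMultipleChoice

    # Tiny random model; hyperparameters are illustrative only.
    config = DistilBertConfig(vocab_size=100, dim=32, hidden_dim=64, n_layers=2, n_heads=4)
    model = DistilBertForMultipleChoice(config)
    model.eval()

    batch_size, num_choices, seq_len = 3, 4, 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))
    labels = torch.randint(0, num_choices, (batch_size,))

    with torch.no_grad():
        loss, logits = model(input_ids, labels=labels)[:2]

    print(logits.shape)  # torch.Size([3, 4]): one score per (example, choice) pair
    print(loss.item())   # cross-entropy over the num_choices scores per example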
tests/test_modeling_distilbert.py

@@ -28,6 +28,7 @@ if is_torch_available():
        DistilBertConfig,
        DistilBertModel,
        DistilBertForMaskedLM,
        DistilBertForMultipleChoice,
        DistilBertForTokenClassification,
        DistilBertForQuestionAnswering,
        DistilBertForSequenceClassification,

...

@@ -41,6 +42,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
        (
            DistilBertModel,
            DistilBertForMaskedLM,
            DistilBertForMultipleChoice,
            DistilBertForQuestionAnswering,
            DistilBertForSequenceClassification,
            DistilBertForTokenClassification,

...

@@ -218,6 +220,25 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
            )
            self.check_loss_output(result)

        def create_and_check_distilbert_for_multiple_choice(
            self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_choices = self.num_choices
            model = DistilBertForMultipleChoice(config=config)
            model.to(torch_device)
            model.eval()
            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
            loss, logits = model(
                multiple_choice_inputs_ids,
                attention_mask=multiple_choice_input_mask,
                labels=choice_labels,
            )
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs

...

@@ -251,6 +272,10 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)

    # @slow
    # def test_model_from_pretrained(self):
    #     for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: