Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
48a8f3da
Unverified
Commit
48a8f3da
authored
May 10, 2022
by
Jason Phang
Committed by
GitHub
May 10, 2022
Browse files
Add DebertaV2ForMultipleChoice (#17135)
parent
4ad2f68e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
144 additions
and
0 deletions
+144
-0
docs/source/en/model_doc/deberta-v2.mdx
docs/source/en/model_doc/deberta-v2.mdx
+5
-0
src/transformers/__init__.py
src/transformers/__init__.py
+2
-0
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+1
-0
src/transformers/models/deberta_v2/__init__.py
src/transformers/models/deberta_v2/__init__.py
+2
-0
src/transformers/models/deberta_v2/modeling_deberta_v2.py
src/transformers/models/deberta_v2/modeling_deberta_v2.py
+104
-0
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+7
-0
tests/models/deberta_v2/test_modeling_deberta_v2.py
tests/models/deberta_v2/test_modeling_deberta_v2.py
+23
-0
No files found.
docs/source/en/model_doc/deberta-v2.mdx
View file @
48a8f3da
...
@@ -107,6 +107,11 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
...
@@ -107,6 +107,11 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
[[autodoc]] DebertaV2ForQuestionAnswering
[[autodoc]] DebertaV2ForQuestionAnswering
- forward
- forward
## DebertaV2ForMultipleChoice
[[autodoc]] DebertaV2ForMultipleChoice
- forward
## TFDebertaV2Model
## TFDebertaV2Model
[[autodoc]] TFDebertaV2Model
[[autodoc]] TFDebertaV2Model
...
...
src/transformers/__init__.py
View file @
48a8f3da
...
@@ -948,6 +948,7 @@ else:
...
@@ -948,6 +948,7 @@ else:
[
[
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DebertaV2ForMaskedLM"
,
"DebertaV2ForMaskedLM"
,
"DebertaV2ForMultipleChoice"
,
"DebertaV2ForQuestionAnswering"
,
"DebertaV2ForQuestionAnswering"
,
"DebertaV2ForSequenceClassification"
,
"DebertaV2ForSequenceClassification"
,
"DebertaV2ForTokenClassification"
,
"DebertaV2ForTokenClassification"
,
...
@@ -3296,6 +3297,7 @@ if TYPE_CHECKING:
...
@@ -3296,6 +3297,7 @@ if TYPE_CHECKING:
from
.models.deberta_v2
import
(
from
.models.deberta_v2
import
(
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
,
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
,
DebertaV2ForMaskedLM
,
DebertaV2ForMaskedLM
,
DebertaV2ForMultipleChoice
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForSequenceClassification
,
DebertaV2ForSequenceClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForTokenClassification
,
...
...
src/transformers/models/auto/modeling_auto.py
View file @
48a8f3da
...
@@ -597,6 +597,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
...
@@ -597,6 +597,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
(
"funnel"
,
"FunnelForMultipleChoice"
),
(
"funnel"
,
"FunnelForMultipleChoice"
),
(
"mpnet"
,
"MPNetForMultipleChoice"
),
(
"mpnet"
,
"MPNetForMultipleChoice"
),
(
"ibert"
,
"IBertForMultipleChoice"
),
(
"ibert"
,
"IBertForMultipleChoice"
),
(
"deberta-v2"
,
"DebertaV2ForMultipleChoice"
),
]
]
)
)
...
...
src/transformers/models/deberta_v2/__init__.py
View file @
48a8f3da
...
@@ -65,6 +65,7 @@ else:
...
@@ -65,6 +65,7 @@ else:
_import_structure
[
"modeling_deberta_v2"
]
=
[
_import_structure
[
"modeling_deberta_v2"
]
=
[
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"DebertaV2ForMaskedLM"
,
"DebertaV2ForMaskedLM"
,
"DebertaV2ForMultipleChoice"
,
"DebertaV2ForQuestionAnswering"
,
"DebertaV2ForQuestionAnswering"
,
"DebertaV2ForSequenceClassification"
,
"DebertaV2ForSequenceClassification"
,
"DebertaV2ForTokenClassification"
,
"DebertaV2ForTokenClassification"
,
...
@@ -110,6 +111,7 @@ if TYPE_CHECKING:
...
@@ -110,6 +111,7 @@ if TYPE_CHECKING:
from
.modeling_deberta_v2
import
(
from
.modeling_deberta_v2
import
(
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
,
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
,
DebertaV2ForMaskedLM
,
DebertaV2ForMaskedLM
,
DebertaV2ForMultipleChoice
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForSequenceClassification
,
DebertaV2ForSequenceClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForTokenClassification
,
...
...
src/transformers/models/deberta_v2/modeling_deberta_v2.py
View file @
48a8f3da
...
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
...
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
from
...modeling_outputs
import
(
from
...modeling_outputs
import
(
BaseModelOutput
,
BaseModelOutput
,
MaskedLMOutput
,
MaskedLMOutput
,
MultipleChoiceModelOutput
,
QuestionAnsweringModelOutput
,
QuestionAnsweringModelOutput
,
SequenceClassifierOutput
,
SequenceClassifierOutput
,
TokenClassifierOutput
,
TokenClassifierOutput
,
...
@@ -1511,3 +1512,106 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
...
@@ -1511,3 +1512,106 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
hidden_states
=
outputs
.
hidden_states
,
hidden_states
=
outputs
.
hidden_states
,
attentions
=
outputs
.
attentions
,
attentions
=
outputs
.
attentions
,
)
)
@add_start_docstrings(
    """
    DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
    def __init__(self, config):
        """Build the encoder, the context pooler, the dropout and the one-score-per-choice classifier."""
        super().__init__(config)

        # `num_labels` is stored on the instance but is not used by this head:
        # the classifier below always emits a single score per choice.
        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels

        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim

        self.classifier = nn.Linear(output_dim, 1)
        # Prefer the head-specific `cls_dropout` when the config defines it,
        # otherwise fall back to the generic hidden dropout probability.
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = StableDropout(drop_out)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped `self.deberta` encoder."""
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the input embeddings of the wrapped `self.deberta` encoder."""
        self.deberta.set_input_embeddings(new_embeddings)

    # Inputs carry an extra `num_choices` axis (read from `shape[1]` below), so the
    # documented input shape has to include it.
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fail with a clear message instead of an AttributeError on `None.shape`.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice axis into the leading axis so the encoder sees one
        # sequence per (example, choice) pair; only the trailing axes are kept.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.deberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # One score per flattened row -> regroup the scores per example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            # Tuple output: logits followed by whatever the encoder returned after its first element.
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
src/transformers/utils/dummy_pt_objects.py
View file @
48a8f3da
...
@@ -1406,6 +1406,13 @@ class DebertaV2ForMaskedLM(metaclass=DummyObject):
...
@@ -1406,6 +1406,13 @@ class DebertaV2ForMaskedLM(metaclass=DummyObject):
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
# Placeholder that shares its name with the real PyTorch model class. It lists
# "torch" as its backend and hands itself to `requires_backends` on construction
# (presumably raising when torch is not installed -- confirm against
# `requires_backends` / `DummyObject`, which are defined outside this view).
class DebertaV2ForMultipleChoice(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only the backend check runs.
        requires_backends(self, ["torch"])
class
DebertaV2ForQuestionAnswering
(
metaclass
=
DummyObject
):
class
DebertaV2ForQuestionAnswering
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
_backends
=
[
"torch"
]
...
...
tests/models/deberta_v2/test_modeling_deberta_v2.py
View file @
48a8f3da
...
@@ -26,6 +26,7 @@ if is_torch_available():
...
@@ -26,6 +26,7 @@ if is_torch_available():
from
transformers
import
(
from
transformers
import
(
DebertaV2ForMaskedLM
,
DebertaV2ForMaskedLM
,
DebertaV2ForMultipleChoice
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForSequenceClassification
,
DebertaV2ForSequenceClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForTokenClassification
,
...
@@ -192,6 +193,23 @@ class DebertaV2ModelTester(object):
...
@@ -192,6 +193,23 @@ class DebertaV2ModelTester(object):
self
.
parent
.
assertEqual
(
result
.
start_logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
))
self
.
parent
.
assertEqual
(
result
.
start_logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
))
self
.
parent
.
assertEqual
(
result
.
end_logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
))
self
.
parent
.
assertEqual
(
result
.
end_logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
))
def create_and_check_deberta_for_multiple_choice(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Run `DebertaV2ForMultipleChoice` on inputs tiled across choices and check the logits shape.

    Each 2-D input gets a new axis at position 1 that is expanded to
    `self.num_choices`; the resulting logits must have shape
    `(self.batch_size, self.num_choices)`. `sequence_labels` and
    `token_labels` are accepted but unused here.
    """

    def _tile_over_choices(tensor):
        # Insert an axis at position 1, expand it to num_choices, and make the result contiguous.
        return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

    model = DebertaV2ForMultipleChoice(config=config)
    model.to(torch_device)
    model.eval()

    tiled_input_ids = _tile_over_choices(input_ids)
    tiled_token_type_ids = _tile_over_choices(token_type_ids)
    tiled_input_mask = _tile_over_choices(input_mask)

    result = model(
        tiled_input_ids,
        attention_mask=tiled_input_mask,
        token_type_ids=tiled_token_type_ids,
        labels=choice_labels,
    )

    expected_shape = (self.batch_size, self.num_choices)
    self.parent.assertEqual(result.logits.shape, expected_shape)
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
(
...
@@ -217,6 +235,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -217,6 +235,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
DebertaV2ForSequenceClassification
,
DebertaV2ForSequenceClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForMultipleChoice
,
)
)
if
is_torch_available
()
if
is_torch_available
()
else
()
else
()
...
@@ -254,6 +273,10 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -254,6 +273,10 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_deberta_for_token_classification
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_deberta_for_token_classification
(
*
config_and_inputs
)
def test_for_multiple_choice(self):
    """Prepare the tester's config and inputs and run the multiple-choice check on them."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_deberta_for_multiple_choice(*prepared)
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
for
model_name
in
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
for
model_name
in
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
[:
1
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment