chenpangpang / transformers

Unverified commit d541938c, authored Jun 10, 2020 by Sylvain Gugger, committed by GitHub on Jun 10, 2020
Make multiple choice models work with input_embeds (#4921)
parent 1e2631d6
Showing 5 changed files with 46 additions and 23 deletions (+46 -23)
src/transformers/modeling_bert.py          +7  -2
src/transformers/modeling_longformer.py    +8  -2
src/transformers/modeling_roberta.py       +9  -2
src/transformers/modeling_xlnet.py         +8  -3
tests/test_modeling_common.py              +14 -14
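In effect, each of these multiple-choice heads now accepts inputs_embeds in place of input_ids, as the other task heads already did. A minimal sketch of the new call path (the tiny random config, tensor sizes, and variable names below are illustrative, not part of this commit):

import torch
from transformers import BertConfig, BertForMultipleChoice

# Illustrative setup only: a small random-weight config; real use would load a pretrained checkpoint.
config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
model = BertForMultipleChoice(config)
model.eval()

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))

# The new code path: pass precomputed embeddings instead of token ids.
# Shape is (batch, num_choices, seq_len, hidden); the head flattens the first two dims itself.
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
    logits = model(inputs_embeds=inputs_embeds)[0]
print(logits.shape)  # torch.Size([2, 4]) -- one score per choice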
src/transformers/modeling_bert.py
@@ -1359,12 +1359,17 @@ class BertForMultipleChoice(BertPreTrainedModel):
         # the linear classifier still needs to be trained
         loss, logits = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
-        input_ids = input_ids.view(-1, input_ids.size(-1))
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.bert(
             input_ids,
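One asymmetry in the new flattening is worth noting: token ids collapse from (batch, num_choices, seq_len) to (batch * num_choices, seq_len), while embeddings keep their trailing hidden axis. A standalone sketch of the equivalent reshapes (shapes chosen arbitrarily for illustration):

import torch

batch, num_choices, seq_len, hidden = 2, 4, 16, 32
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
inputs_embeds = torch.randn(batch, num_choices, seq_len, hidden)

# Same flattening the model performs before calling the encoder.
flat_ids = input_ids.view(-1, input_ids.size(-1))  # (8, 16)
flat_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))  # (8, 16, 32)

assert flat_ids.shape == (batch * num_choices, seq_len)
assert flat_embeds.shape == (batch * num_choices, seq_len, hidden)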
src/transformers/modeling_longformer.py
@@ -1202,7 +1202,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
         loss, classification_scores = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
         # set global attention on question tokens
         if global_attention_mask is None:
@@ -1216,7 +1216,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
                 dim=1,
             )
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
@@ -1225,6 +1225,11 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             if global_attention_mask is not None
             else None
         )
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.longformer(
             flat_input_ids,
@@ -1232,6 +1237,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             global_attention_mask=flat_global_attention_mask,
+            inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
src/transformers/modeling_roberta.py
@@ -448,18 +448,25 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
         loss, classification_scores = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.roberta(
             flat_input_ids,
             position_ids=flat_position_ids,
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             head_mask=head_mask,
+            inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
src/transformers/modeling_xlnet.py
@@ -1438,12 +1438,17 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         loss, classification_scores = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
         flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         transformer_outputs = self.transformer(
             flat_input_ids,
@@ -1454,7 +1459,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
             perm_mask=perm_mask,
             target_mapping=target_mapping,
             head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
+            inputs_embeds=flat_inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
tests/test_modeling_common.py
@@ -639,31 +639,31 @@ class ModelTesterMixin:
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.is_encoder_decoder:
-            input_ids = inputs_dict["input_ids"]
-            del inputs_dict["input_ids"]
-        else:
-            encoder_input_ids = inputs_dict["input_ids"]
-            decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
-            del inputs_dict["input_ids"]
-            inputs_dict.pop("decoder_input_ids", None)
 
         for model_class in self.all_model_classes:
-            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-                continue
             model = model_class(config)
             model.to(torch_device)
             model.eval()
 
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+
             wte = model.get_input_embeddings()
             if not self.is_encoder_decoder:
-                inputs_dict["inputs_embeds"] = wte(input_ids)
+                inputs["inputs_embeds"] = wte(input_ids)
             else:
-                inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
-                inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
+                inputs["inputs_embeds"] = wte(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
 
             with torch.no_grad():
-                model(**inputs_dict)
+                model(**inputs)
 
     def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
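The test previously skipped every class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING; it now covers them by building per-class inputs with self._prepare_for_class and deep-copying the result so the shared inputs_dict is not mutated across model classes. For multiple-choice classes that helper expands the usual 2D inputs along a num_choices dimension; the function below is a hypothetical stand-in, included only to illustrate that shape change (it is not the actual helper):

import copy
import torch

def expand_for_multiple_choice(inputs_dict, num_choices=4):
    # Hypothetical stand-in: tile each (batch, seq_len) tensor into
    # (batch, num_choices, seq_len) so a multiple-choice head can consume it.
    expanded = copy.deepcopy(inputs_dict)
    for key, value in expanded.items():
        if torch.is_tensor(value) and value.dim() == 2:
            expanded[key] = value.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
    return expanded

# Example: a (batch=2, seq_len=8) input_ids tensor becomes (2, 4, 8).
inputs = {"input_ids": torch.randint(0, 100, (2, 8))}
print(expand_for_multiple_choice(inputs)["input_ids"].shape)  # torch.Size([2, 4, 8])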