chenpangpang / transformers / Commits

Commit d541938c (Unverified)
Make multiple choice models work with input_embeds (#4921)
Authored Jun 10, 2020 by Sylvain Gugger, committed by GitHub on Jun 10, 2020
Parent: 1e2631d6

Showing 5 changed files with 46 additions and 23 deletions
src/transformers/modeling_bert.py        +7  -2
src/transformers/modeling_longformer.py  +8  -2
src/transformers/modeling_roberta.py     +9  -2
src/transformers/modeling_xlnet.py       +8  -3
tests/test_modeling_common.py            +14 -14
src/transformers/modeling_bert.py

@@ -1359,12 +1359,17 @@ class BertForMultipleChoice(BertPreTrainedModel):
             # the linear classifier still needs to be trained
             loss, logits = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
-        input_ids = input_ids.view(-1, input_ids.size(-1))
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.bert(
             input_ids,
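The change is the same in every modeling file: num_choices now falls back to inputs_embeds.shape[1] when input_ids is None, and inputs_embeds is flattened like the other inputs but keeps its trailing hidden dimension. A minimal standalone sketch of the shape mechanics (the sizes below are illustrative, not from the commit):

import torch

batch_size, num_choices, seq_len, hidden_size = 2, 4, 16, 32

# Callers that skip the embedding lookup pass a 4-D tensor instead of
# (batch_size, num_choices, seq_len) token ids.
inputs_embeds = torch.randn(batch_size, num_choices, seq_len, hidden_size)
input_ids = None

# The fixed line: read the choice count from whichever input is present.
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

# Fold choices into the batch dimension; embeddings keep one extra dim.
flat_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
assert flat_embeds.shape == (batch_size * num_choices, seq_len, hidden_size)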
src/transformers/modeling_longformer.py

@@ -1202,7 +1202,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             loss, classification_scores = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
         # set global attention on question tokens
         if global_attention_mask is None:
@@ -1216,7 +1216,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
                 dim=1,
             )
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
@@ -1225,6 +1225,11 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             if global_attention_mask is not None
             else None
         )
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.longformer(
             flat_input_ids,
@@ -1232,6 +1237,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             global_attention_mask=flat_global_attention_mask,
+            inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
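Note the guard pattern repeated across these hunks: every optional tensor goes through x.view(...) if x is not None else None, so the same forward body serves both the ids path and the embeddings path. A sketch of that pattern factored into a helper (flatten is a hypothetical name used here for illustration, not a function in the commit):

import torch
from typing import Optional

def flatten(t: Optional[torch.Tensor], keep_hidden: bool = False) -> Optional[torch.Tensor]:
    # Absent tensors stay None; present ones fold choices into the batch,
    # mirroring the guarded reshapes in the diff above.
    if t is None:
        return None
    if keep_hidden:  # embeddings carry a trailing hidden dimension
        return t.view(-1, t.size(-2), t.size(-1))
    return t.view(-1, t.size(-1))

attention_mask = torch.ones(2, 4, 16, dtype=torch.long)
inputs_embeds = torch.randn(2, 4, 16, 32)

print(flatten(None))                                   # None (e.g. no input_ids)
print(flatten(attention_mask).shape)                   # torch.Size([8, 16])
print(flatten(inputs_embeds, keep_hidden=True).shape)  # torch.Size([8, 16, 32])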
src/transformers/modeling_roberta.py

@@ -448,18 +448,25 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
             loss, classification_scores = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.roberta(
             flat_input_ids,
             position_ids=flat_position_ids,
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             head_mask=head_mask,
+            inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
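For context on why num_choices must be computed correctly: after the flattened forward pass, the multiple-choice head scores each (example, choice) row and restores the choice dimension with view(-1, num_choices). A sketch with a stand-in linear head (not code from this commit):

import torch
import torch.nn as nn

batch_size, num_choices, hidden_size = 2, 4, 32

classifier = nn.Linear(hidden_size, 1)            # stand-in scoring head
pooled_output = torch.randn(batch_size * num_choices, hidden_size)

logits = classifier(pooled_output)                # (batch * choices, 1)
reshaped_logits = logits.view(-1, num_choices)    # (batch, choices)
assert reshaped_logits.shape == (batch_size, num_choices)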
src/transformers/modeling_xlnet.py

@@ -1438,12 +1438,17 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
             loss, classification_scores = outputs[:2]
         """
-        num_choices = input_ids.shape[1]
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
         flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
 
         transformer_outputs = self.transformer(
             flat_input_ids,
@@ -1454,7 +1459,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
             perm_mask=perm_mask,
             target_mapping=target_mapping,
             head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
+            inputs_embeds=flat_inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
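The second XLNet hunk also repairs a latent bug: the call previously forwarded the unflattened inputs_embeds while every other argument had choices folded into the batch. A sketch of the mismatch the fix removes (shapes are illustrative):

import torch

batch_size, num_choices, seq_len, hidden_size = 2, 4, 16, 32
inputs_embeds = torch.randn(batch_size, num_choices, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, num_choices, seq_len)

flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))

# Before the fix: 4-D embeddings next to a flattened 2-D mask.
print(inputs_embeds.shape)        # torch.Size([2, 4, 16, 32])
print(flat_attention_mask.shape)  # torch.Size([8, 16])

# After the fix: both agree on a (batch * choices) leading dimension.
flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
print(flat_inputs_embeds.shape)   # torch.Size([8, 16, 32])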
tests/test_modeling_common.py

@@ -639,31 +639,31 @@ class ModelTesterMixin:
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.is_encoder_decoder:
-            input_ids = inputs_dict["input_ids"]
-            del inputs_dict["input_ids"]
-        else:
-            encoder_input_ids = inputs_dict["input_ids"]
-            decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
-            del inputs_dict["input_ids"]
-            inputs_dict.pop("decoder_input_ids", None)
 
         for model_class in self.all_model_classes:
-            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-                continue
             model = model_class(config)
             model.to(torch_device)
             model.eval()
 
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
 
             wte = model.get_input_embeddings()
             if not self.is_encoder_decoder:
-                inputs_dict["inputs_embeds"] = wte(input_ids)
+                inputs["inputs_embeds"] = wte(input_ids)
             else:
-                inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
-                inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
+                inputs["inputs_embeds"] = wte(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
 
             with torch.no_grad():
-                model(**inputs_dict)
+                model(**inputs)
 
     def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
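What the reworked test exercises, end to end: each model class is now called with inputs_embeds (and no input_ids), including the multiple-choice classes that were previously skipped, each on a fresh per-class copy of the inputs instead of a shared mutated inputs_dict. A minimal standalone smoke test in the same spirit, assuming a transformers version that includes this commit (the tiny config values are illustrative):

import torch
from transformers import BertConfig, BertForMultipleChoice

# Tiny, randomly initialized model; sizes are illustrative only.
config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
)
model = BertForMultipleChoice(config).eval()

batch_size, num_choices, seq_len = 2, 4, 8
inputs_embeds = torch.randn(batch_size, num_choices, seq_len, config.hidden_size)

# Before this commit, this call raised AttributeError because the forward
# pass read num_choices from input_ids.shape with input_ids=None.
with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds)

print(outputs[0].shape)  # torch.Size([2, 4]): one logit per choice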