Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
18c8cf00
"vscode:/vscode.git/clone" did not exist on "ecdf9b06bc03af272ceb8d6951e30e677fdfd35c"
Unverified
Commit
18c8cf00
authored
Nov 23, 2020
by
Yossi Synett
Committed by
GitHub
Nov 23, 2020
Browse files
Fix bug in x-attentions output for roberta and harden test to catch it (#8660)
parent
48cc2247
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
11 deletions
+13
-11
src/transformers/models/roberta/modeling_roberta.py
src/transformers/models/roberta/modeling_roberta.py
+1
-1
tests/test_modeling_encoder_decoder.py
tests/test_modeling_encoder_decoder.py
+12
-10
No files found.
src/transformers/models/roberta/modeling_roberta.py
View file @
18c8cf00
...
@@ -814,7 +814,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
...
@@ -814,7 +814,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
logits
=
prediction_scores
,
logits
=
prediction_scores
,
hidden_states
=
outputs
.
hidden_states
,
hidden_states
=
outputs
.
hidden_states
,
attentions
=
outputs
.
attentions
,
attentions
=
outputs
.
attentions
,
cross_attentions
=
outputs
.
attentions
,
cross_attentions
=
outputs
.
cross_
attentions
,
)
)
def
prepare_inputs_for_generation
(
self
,
input_ids
,
attention_mask
=
None
,
**
model_kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
attention_mask
=
None
,
**
model_kwargs
):
...
...
tests/test_modeling_encoder_decoder.py
View file @
18c8cf00
...
@@ -300,6 +300,9 @@ class EncoderDecoderMixin:
...
@@ -300,6 +300,9 @@ class EncoderDecoderMixin:
labels
,
labels
,
**
kwargs
**
kwargs
):
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
.
to
(
torch_device
)
enc_dec_model
.
to
(
torch_device
)
...
@@ -314,9 +317,8 @@ class EncoderDecoderMixin:
...
@@ -314,9 +317,8 @@ class EncoderDecoderMixin:
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertListEqual
(
self
.
assertEqual
(
list
(
encoder_attentions
[
0
].
shape
[
-
3
:]),
encoder_attentions
[
0
].
shape
[
-
3
:],
(
config
.
num_attention_heads
,
input_ids
.
shape
[
-
1
],
input_ids
.
shape
[
-
1
])
[
config
.
num_attention_heads
,
input_ids
.
shape
[
-
1
],
input_ids
.
shape
[
-
1
]],
)
)
decoder_attentions
=
outputs_encoder_decoder
[
"decoder_attentions"
]
decoder_attentions
=
outputs_encoder_decoder
[
"decoder_attentions"
]
...
@@ -327,20 +329,20 @@ class EncoderDecoderMixin:
...
@@ -327,20 +329,20 @@ class EncoderDecoderMixin:
)
)
self
.
assertEqual
(
len
(
decoder_attentions
),
num_decoder_layers
)
self
.
assertEqual
(
len
(
decoder_attentions
),
num_decoder_layers
)
self
.
assert
List
Equal
(
self
.
assertEqual
(
list
(
decoder_attentions
[
0
].
shape
[
-
3
:]
)
,
decoder_attentions
[
0
].
shape
[
-
3
:],
[
decoder_config
.
num_attention_heads
,
decoder_input_ids
.
shape
[
-
1
],
decoder_input_ids
.
shape
[
-
1
]
]
,
(
decoder_config
.
num_attention_heads
,
decoder_input_ids
.
shape
[
-
1
],
decoder_input_ids
.
shape
[
-
1
]
)
,
)
)
cross_attentions
=
outputs_encoder_decoder
[
"cross_attentions"
]
cross_attentions
=
outputs_encoder_decoder
[
"cross_attentions"
]
self
.
assertEqual
(
len
(
cross_attentions
),
num_decoder_layers
)
self
.
assertEqual
(
len
(
cross_attentions
),
num_decoder_layers
)
cross_attention_input_seq_len
=
input_ids
.
shape
[
-
1
]
*
(
cross_attention_input_seq_len
=
decoder_
input_ids
.
shape
[
-
1
]
*
(
1
+
(
decoder_config
.
ngram
if
hasattr
(
decoder_config
,
"ngram"
)
else
0
)
1
+
(
decoder_config
.
ngram
if
hasattr
(
decoder_config
,
"ngram"
)
else
0
)
)
)
self
.
assert
List
Equal
(
self
.
assertEqual
(
list
(
cross_attentions
[
0
].
shape
[
-
3
:]
)
,
cross_attentions
[
0
].
shape
[
-
3
:],
[
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
decoder_
input_ids
.
shape
[
-
1
]
]
,
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]
)
,
)
)
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment