chenpangpang / transformers · Commits

Commit 268d4f20
authored Nov 08, 2019 by thomwolf

fix position biases + better tests

parent b4fcd59a

Changes: 2 changed files with 42 additions and 31 deletions (+42 -31)

    transformers/modeling_t5.py             +7  -4
    transformers/tests/modeling_t5_test.py  +35 -27
transformers/modeling_t5.py
@@ -408,7 +408,7 @@ class T5Block(nn.Module):
                                                position_bias=position_bias,
                                                head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]
+        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

         if not self.is_decoder:
             hidden_states = self.layer[1](hidden_states)
@@ -419,11 +419,11 @@ class T5Block(nn.Module):
                                                          position_bias=encoder_decoder_position_bias,
                                                          head_mask=head_mask)
             hidden_states = cross_attention_outputs[0]
-            outputs = cross_attention_outputs[1:] + outputs
+            outputs = outputs + cross_attention_outputs[1:]  # Keep cross-attention outputs and relative position weights

             hidden_states = self.layer[2](hidden_states)

         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs
+        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


 class T5PreTrainedModel(PreTrainedModel):
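The reordering above (outputs + cross_attention_outputs[1:] instead of the reverse) keeps the self-attention entries ahead of the cross-attention ones, so the tuple layout spelled out in the new return comment holds. A minimal sketch of how a caller can index into that tuple, using placeholder strings instead of tensors; the output_attentions flag and the index arithmetic simply mirror the T5Stack hunk below and are not library code:

output_attentions = True  # assumed setting for this illustration
layer_outputs = ("hidden_states",
                 "self_attention_weights", "self_attention_position_bias",
                 "cross_attention_weights", "cross_attention_position_bias")

hidden_states = layer_outputs[0]
self_attention_position_bias = layer_outputs[2 if output_attentions else 1]
cross_attention_position_bias = layer_outputs[4 if output_attentions else 2]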
@@ -564,14 +564,17 @@ class T5Stack(T5PreTrainedModel):
                                          encoder_attention_mask=encoder_extended_attention_mask,
                                          encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
+            # layer_outputs is a tuple with:
+            # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states = layer_outputs[0]
             if i == 0:
+                # We share the position biases between the layers - the first layer store them
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]

             if self.output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_attentions = all_attentions + (layer_outputs[1],)  # We keep only self-attention weights for now

         hidden_states = self.final_layer_norm(hidden_states)
         layer_output = self.dropout(hidden_states)
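This hunk is the core of the "fix position biases" part of the commit: only the first layer computes the relative position biases, and every later layer receives them back through the position_bias / encoder_decoder_position_bias keyword arguments instead of recomputing them. A simplified, self-contained sketch of that sharing pattern; the dummy layer stands in for T5Block and the names are reused only for readability, this is not the library code:

def run_stack(layers, hidden_states, output_attentions=False, is_decoder=False):
    # Sketch of the sharing pattern in T5Stack above.
    position_bias = None
    encoder_decoder_position_bias = None
    for i, layer_module in enumerate(layers):
        layer_outputs = layer_module(hidden_states,
                                     position_bias=position_bias,
                                     encoder_decoder_position_bias=encoder_decoder_position_bias)
        hidden_states = layer_outputs[0]
        if i == 0:
            # The first layer computed the biases; reuse them for all later layers.
            position_bias = layer_outputs[2 if output_attentions else 1]
            if is_decoder:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 2]
    return hidden_states

def dummy_layer(hidden_states, position_bias=None, encoder_decoder_position_bias=None):
    # Echo the input and return a made-up "bias" placeholder in position 1.
    computed_bias = position_bias if position_bias is not None else "bias_from_layer_0"
    return (hidden_states, computed_bias)

print(run_stack([dummy_layer, dummy_layer, dummy_layer], "h"))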
transformers/tests/modeling_t5_test.py
@@ -45,9 +45,10 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         def __init__(self,
                      parent,
                      batch_size=13,
-                     seq_length=7,
+                     encoder_seq_length=7,
+                     decoder_seq_length=9,
                      is_training=True,
-                     use_input_mask=True,
+                     use_attention_mask=True,
                      use_labels=True,
                      vocab_size=99,
                      n_positions=14,
@@ -62,9 +63,10 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
                      ):
             self.parent = parent
             self.batch_size = batch_size
-            self.seq_length = seq_length
+            self.encoder_seq_length = encoder_seq_length
+            self.decoder_seq_length = decoder_seq_length
             self.is_training = is_training
-            self.use_input_mask = use_input_mask
+            self.use_attention_mask = use_attention_mask
             self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.n_positions = n_positions
@@ -78,15 +80,18 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
             self.scope = scope

         def prepare_config_and_inputs(self):
-            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            encoder_input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+            decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

-            input_mask = None
-            if self.use_input_mask:
-                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            encoder_attention_mask = None
+            decoder_attention_mask = None
+            if self.use_attention_mask:
+                encoder_attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+                decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

-            token_labels = None
+            decoder_lm_labels = None
             if self.use_labels:
-                token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

             config = T5Config(
                 vocab_size_or_config_json_file=self.vocab_size,
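For reference, with the tester defaults above (batch_size=13, encoder_seq_length=7, decoder_seq_length=9, vocab_size=99), the prepared inputs get the shapes below. The sketch uses torch.randint as a stand-in for the ids_tensor test helper:

import torch

batch_size, encoder_seq_length, decoder_seq_length, vocab_size = 13, 7, 9, 99
encoder_input_ids = torch.randint(0, vocab_size, (batch_size, encoder_seq_length))   # (13, 7)
decoder_input_ids = torch.randint(0, vocab_size, (batch_size, decoder_seq_length))   # (13, 9)
encoder_attention_mask = torch.randint(0, 2, (batch_size, encoder_seq_length))       # (13, 7), values in {0, 1}
decoder_attention_mask = torch.randint(0, 2, (batch_size, decoder_seq_length))       # (13, 9), values in {0, 1}
decoder_lm_labels = torch.randint(0, vocab_size, (batch_size, decoder_seq_length))   # (13, 9)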
@@ -100,21 +105,22 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
                 dropout_rate=self.dropout_rate,
                 initializer_factor=self.initializer_factor)

-            return (config, input_ids, input_mask, token_labels)
+            return (config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels)

         def check_loss_output(self, result):
             self.parent.assertListEqual(list(result["loss"].size()), [])

-        def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
+        def create_and_check_t5_model(self, config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels):
             model = T5Model(config=config)
             model.eval()
-            encoder_output, decoder_output = model(encoder_input_ids=input_ids,
-                                                   decoder_input_ids=input_ids,
-                                                   decoder_attention_mask=input_mask)
-            encoder_output, decoder_output = model(encoder_input_ids=input_ids,
-                                                   decoder_input_ids=input_ids)
+            decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
+                                                   decoder_input_ids=decoder_input_ids,
+                                                   encoder_attention_mask=encoder_attention_mask,
+                                                   decoder_attention_mask=decoder_attention_mask)
+            decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
+                                                   decoder_input_ids=decoder_input_ids)
             result = {
                 "encoder_output": encoder_output,
@@ -122,17 +128,17 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
             }
             self.parent.assertListEqual(
                 list(result["encoder_output"].size()),
-                [self.batch_size, self.seq_length, self.hidden_size])
+                [self.batch_size, self.encoder_seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["decoder_output"].size()),
-                [self.batch_size, self.seq_length, self.hidden_size])
+                [self.batch_size, self.decoder_seq_length, self.hidden_size])

-        def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
+        def create_and_check_t5_with_lm_head(self, config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels):
             model = T5WithLMHeadModel(config=config)
             model.eval()
-            outputs = model(encoder_input_ids=input_ids, decoder_input_ids=input_ids,
-                            decoder_attention_mask=input_mask, decoder_lm_labels=token_labels)
+            outputs = model(encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids,
+                            decoder_attention_mask=decoder_attention_mask, decoder_lm_labels=decoder_lm_labels)
             loss, prediction_scores = outputs[0], outputs[1]
             result = {
                 "loss": loss,
@@ -140,15 +146,17 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
             }
             self.parent.assertListEqual(
                 list(result["prediction_scores"].size()),
-                [self.batch_size, self.seq_length, self.vocab_size])
+                [self.batch_size, self.decoder_seq_length, self.vocab_size])
             self.check_loss_output(result)

         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
-            (config, input_ids, input_mask, token_labels) = config_and_inputs
-            inputs_dict = {'encoder_input_ids': input_ids,
-                           'decoder_input_ids': input_ids,
-                           'decoder_attention_mask': input_mask}
+            (config, encoder_input_ids, decoder_input_ids,
+             encoder_attention_mask, decoder_attention_mask, decoder_lm_labels) = config_and_inputs
+            inputs_dict = {'encoder_input_ids': encoder_input_ids,
+                           'decoder_input_ids': decoder_input_ids,
+                           'decoder_attention_mask': decoder_attention_mask,
+                           'encoder_attention_mask': encoder_attention_mask}
             return config, inputs_dict

     def setUp(self):
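Taken together, the "better tests" half of the commit feeds the model genuinely different encoder and decoder lengths, which is what exposes position-bias bugs that a single shared seq_length would mask. Below is a self-contained sketch of the shape contract the updated assertions encode, with a stand-in for the model rather than the real T5Model; the decoder-first return order and the inputs_dict keys mirror the calls in the diff above, and hidden_size here is an arbitrary illustrative value:

import torch

hidden_size = 32

def fake_t5(encoder_input_ids, decoder_input_ids,
            encoder_attention_mask=None, decoder_attention_mask=None):
    # Stand-in returning (decoder_output, encoder_output) like the updated test calls.
    encoder_output = torch.zeros(*encoder_input_ids.shape, hidden_size)
    decoder_output = torch.zeros(*decoder_input_ids.shape, hidden_size)
    return decoder_output, encoder_output

inputs_dict = {
    'encoder_input_ids': torch.randint(0, 99, (13, 7)),
    'decoder_input_ids': torch.randint(0, 99, (13, 9)),
    'decoder_attention_mask': torch.randint(0, 2, (13, 9)),
    'encoder_attention_mask': torch.randint(0, 2, (13, 7)),
}
decoder_output, encoder_output = fake_t5(**inputs_dict)
assert list(encoder_output.size()) == [13, 7, hidden_size]  # batch, encoder_seq_length, hidden
assert list(decoder_output.size()) == [13, 9, hidden_size]  # batch, decoder_seq_length, hidden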