chenpangpang/transformers · Commits

Commit a3274ac4, authored Jun 03, 2019 by thomwolf
adding attention outputs in bert
Parent: 82649658

Showing 2 changed files with 68 additions and 9 deletions:
  pytorch_pretrained_bert/modeling.py   +34  -9
  tests/modeling_gpt2_test.py           +34  -0

pytorch_pretrained_bert/modeling.py
@@ -275,12 +275,13 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+        self.output_attentions = output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -322,6 +323,8 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
+        if self.output_attentions:
+            return attention_probs, context_layer
         return context_layer
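
With this hunk, BertSelfAttention.forward returns the tuple (attention_probs, context_layer) instead of just context_layer whenever output_attentions is set. Below is a minimal sketch of the new return contract; it is not part of the commit, the toy batch and sequence sizes are made up, and the mask is the additive "extended" mask the module expects:

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertSelfAttention

# Default bert-base sizes: hidden_size=768, num_attention_heads=12.
config = BertConfig(vocab_size_or_config_json_file=30522)
attn = BertSelfAttention(config, output_attentions=True)
attn.eval()  # disable attention dropout so the returned probabilities are the plain softmax

hidden_states = torch.randn(2, 5, config.hidden_size)   # [batch, seq_len, hidden_size]
extended_mask = torch.zeros(2, 1, 1, 5)                  # additive mask: 0.0 = attend, large negative = masked
attention_probs, context_layer = attn(hidden_states, extended_mask)

# attention_probs: [batch, num_attention_heads, seq_len, seq_len]
# context_layer:   [batch, seq_len, hidden_size]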
@@ -340,14 +343,19 @@ class BertSelfOutput(nn.Module):
 class BertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertAttention, self).__init__()
-        self.self = BertSelfAttention(config)
+        self.output_attentions = output_attentions
+        self.self = BertSelfAttention(config, output_attentions=output_attentions)
         self.output = BertSelfOutput(config)

     def forward(self, input_tensor, attention_mask):
         self_output = self.self(input_tensor, attention_mask)
+        if self.output_attentions:
+            attentions, self_output = self_output
         attention_output = self.output(self_output, input_tensor)
+        if self.output_attentions:
+            return attentions, attention_output
         return attention_output
@@ -381,33 +389,45 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
-        self.attention = BertAttention(config)
+        self.output_attentions = output_attentions
+        self.attention = BertAttention(config, output_attentions=output_attentions)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
+        if self.output_attentions:
+            attentions, attention_output = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attentions, layer_output
         return layer_output


 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        self.output_attentions = output_attentions
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
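
BertEncoder now threads the same flag through every copied BertLayer and accumulates one attention map per layer in all_attentions, returned ahead of all_encoder_layers. A short sketch of the list it produces, again with made-up sizes and a deliberately small number of layers; this illustrates the change and is not code from the repository:

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertEncoder

config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=4)
encoder = BertEncoder(config, output_attentions=True)
encoder.eval()

hidden_states = torch.randn(2, 7, config.hidden_size)
extended_mask = torch.zeros(2, 1, 1, 7)
all_attentions, all_encoder_layers = encoder(hidden_states, extended_mask)

assert len(all_attentions) == config.num_hidden_layers                          # one map per layer
assert list(all_attentions[0].size()) == [2, config.num_attention_heads, 7, 7]
assert len(all_encoder_layers) == config.num_hidden_layers                      # output_all_encoded_layers defaults to True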
@@ -699,10 +719,11 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
+        self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
@@ -731,10 +752,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
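
End to end, the flag now flows BertModel → BertEncoder → BertLayer → BertAttention → BertSelfAttention, so a model built with output_attentions=True returns the attention maps as an extra first element. A hedged usage sketch mirroring the docstring example above; the input tensors are random placeholders rather than values from the commit:

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=30522)
model = BertModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 9))
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)

with torch.no_grad():
    all_attentions, encoded_layers, pooled_output = model(
        input_ids, token_type_ids, input_mask, output_all_encoded_layers=False)

# all_attentions: list of config.num_hidden_layers tensors, each [batch, num_heads, seq_len, seq_len]
# encoded_layers: [batch, seq_len, hidden_size]  (last layer only, since output_all_encoded_layers=False)
# pooled_output:  [batch, hidden_size]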
tests/modeling_gpt2_test.py
@@ -133,11 +133,28 @@ class GPT2ModelTest(unittest.TestCase):
             }
             return outputs

+        def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                       mc_labels, lm_labels, mc_token_ids):
+            model = GPT2LMHeadModel(config, output_attentions=True)
+            model.eval()
+            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
+            attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "presents": presents,
+                "attentions": attentions,
+            }
+            return outputs
+
         def check_gpt2_lm_head_output(self, result):
             total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.n_choices, self.seq_length, total_voc])
+            self.parent.assertListEqual(
+                list(result["presents"].size()),
+                [self.batch_size, self.n_choices, self.seq_length, total_voc])

         def check_gpt2_lm_head_loss_output(self, result):
             self.parent.assertListEqual(
@@ -160,6 +177,23 @@ class GPT2ModelTest(unittest.TestCase):
             }
             return outputs

+        def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                            mc_labels, lm_labels, mc_token_ids):
+            model = GPT2DoubleHeadsModel(config, output_attentions=True)
+            model.eval()
+            loss = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
+                         token_type_ids=token_type_ids, position_ids=position_ids)
+            attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids,
+                                                               position_ids=position_ids, token_type_ids=token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "mc_logits": mc_logits,
+                "presents": presents,
+                "attentions": attentions,
+            }
+            return outputs
+
         def check_gpt2_double_heads_output(self, result):
             total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
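
The test changes only add helpers for GPT2LMHeadModel and GPT2DoubleHeadsModel, which presumably already accept the output_attentions flag; the new BERT path in modeling.py is not exercised here. A hypothetical unittest in the same spirit, written as an illustration and not part of the commit, could check the BERT attention shapes like this:

import unittest
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel


class BertOutputAttentionsTest(unittest.TestCase):
    def test_bert_model_output_attentions(self):
        # Deliberately tiny config so the test runs fast; sizes are arbitrary.
        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                            num_hidden_layers=3, num_attention_heads=4, intermediate_size=37)
        model = BertModel(config, output_attentions=True)
        model.eval()

        input_ids = torch.randint(0, config.vocab_size, (2, 7))
        all_attentions, encoded_layers, pooled_output = model(
            input_ids, output_all_encoded_layers=False)

        self.assertEqual(len(all_attentions), config.num_hidden_layers)
        for attentions in all_attentions:
            self.assertListEqual(list(attentions.size()),
                                 [2, config.num_attention_heads, 7, 7])


if __name__ == "__main__":
    unittest.main()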