Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a3274ac4
Commit
a3274ac4
authored
Jun 03, 2019
by
thomwolf
Browse files
adding attention outputs in bert
parent
82649658
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
68 additions
and
9 deletions
+68
-9
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+34
-9
tests/modeling_gpt2_test.py
tests/modeling_gpt2_test.py
+34
-0
No files found.
pytorch_pretrained_bert/modeling.py
View file @
a3274ac4
...
@@ -275,12 +275,13 @@ class BertEmbeddings(nn.Module):
...
@@ -275,12 +275,13 @@ class BertEmbeddings(nn.Module):
class
BertSelfAttention
(
nn
.
Module
):
class
BertSelfAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
super
(
BertSelfAttention
,
self
).
__init__
()
super
(
BertSelfAttention
,
self
).
__init__
()
if
config
.
hidden_size
%
config
.
num_attention_heads
!=
0
:
if
config
.
hidden_size
%
config
.
num_attention_heads
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The hidden size (%d) is not a multiple of the number of attention "
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
config
.
hidden_size
,
config
.
num_attention_heads
))
"heads (%d)"
%
(
config
.
hidden_size
,
config
.
num_attention_heads
))
self
.
output_attentions
=
output_attentions
self
.
num_attention_heads
=
config
.
num_attention_heads
self
.
num_attention_heads
=
config
.
num_attention_heads
self
.
attention_head_size
=
int
(
config
.
hidden_size
/
config
.
num_attention_heads
)
self
.
attention_head_size
=
int
(
config
.
hidden_size
/
config
.
num_attention_heads
)
self
.
all_head_size
=
self
.
num_attention_heads
*
self
.
attention_head_size
self
.
all_head_size
=
self
.
num_attention_heads
*
self
.
attention_head_size
...
@@ -322,6 +323,8 @@ class BertSelfAttention(nn.Module):
...
@@ -322,6 +323,8 @@ class BertSelfAttention(nn.Module):
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
if
self
.
output_attentions
:
return
attention_probs
,
context_layer
return
context_layer
return
context_layer
...
@@ -340,14 +343,19 @@ class BertSelfOutput(nn.Module):
...
@@ -340,14 +343,19 @@ class BertSelfOutput(nn.Module):
class
BertAttention
(
nn
.
Module
):
class
BertAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
super
(
BertAttention
,
self
).
__init__
()
super
(
BertAttention
,
self
).
__init__
()
self
.
self
=
BertSelfAttention
(
config
)
self
.
output_attentions
=
output_attentions
self
.
self
=
BertSelfAttention
(
config
,
output_attentions
=
output_attentions
)
self
.
output
=
BertSelfOutput
(
config
)
self
.
output
=
BertSelfOutput
(
config
)
def
forward
(
self
,
input_tensor
,
attention_mask
):
def
forward
(
self
,
input_tensor
,
attention_mask
):
self_output
=
self
.
self
(
input_tensor
,
attention_mask
)
self_output
=
self
.
self
(
input_tensor
,
attention_mask
)
if
self
.
output_attentions
:
attentions
,
self_output
=
self_output
attention_output
=
self
.
output
(
self_output
,
input_tensor
)
attention_output
=
self
.
output
(
self_output
,
input_tensor
)
if
self
.
output_attentions
:
return
attentions
,
attention_output
return
attention_output
return
attention_output
...
@@ -381,33 +389,45 @@ class BertOutput(nn.Module):
...
@@ -381,33 +389,45 @@ class BertOutput(nn.Module):
class
BertLayer
(
nn
.
Module
):
class
BertLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
super
(
BertLayer
,
self
).
__init__
()
super
(
BertLayer
,
self
).
__init__
()
self
.
attention
=
BertAttention
(
config
)
self
.
output_attentions
=
output_attentions
self
.
attention
=
BertAttention
(
config
,
output_attentions
=
output_attentions
)
self
.
intermediate
=
BertIntermediate
(
config
)
self
.
intermediate
=
BertIntermediate
(
config
)
self
.
output
=
BertOutput
(
config
)
self
.
output
=
BertOutput
(
config
)
def
forward
(
self
,
hidden_states
,
attention_mask
):
def
forward
(
self
,
hidden_states
,
attention_mask
):
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
)
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
)
if
self
.
output_attentions
:
attentions
,
attention_output
=
attention_output
intermediate_output
=
self
.
intermediate
(
attention_output
)
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
if
self
.
output_attentions
:
return
attentions
,
layer_output
return
layer_output
return
layer_output
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
super
(
BertEncoder
,
self
).
__init__
()
super
(
BertEncoder
,
self
).
__init__
()
layer
=
BertLayer
(
config
)
self
.
output_attentions
=
output_attentions
layer
=
BertLayer
(
config
,
output_attentions
=
output_attentions
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
def
forward
(
self
,
hidden_states
,
attention_mask
,
output_all_encoded_layers
=
True
):
def
forward
(
self
,
hidden_states
,
attention_mask
,
output_all_encoded_layers
=
True
):
all_encoder_layers
=
[]
all_encoder_layers
=
[]
all_attentions
=
[]
for
layer_module
in
self
.
layer
:
for
layer_module
in
self
.
layer
:
hidden_states
=
layer_module
(
hidden_states
,
attention_mask
)
hidden_states
=
layer_module
(
hidden_states
,
attention_mask
)
if
self
.
output_attentions
:
attentions
,
hidden_states
=
hidden_states
all_attentions
.
append
(
attentions
)
if
output_all_encoded_layers
:
if
output_all_encoded_layers
:
all_encoder_layers
.
append
(
hidden_states
)
all_encoder_layers
.
append
(
hidden_states
)
if
not
output_all_encoded_layers
:
if
not
output_all_encoded_layers
:
all_encoder_layers
.
append
(
hidden_states
)
all_encoder_layers
.
append
(
hidden_states
)
if
self
.
output_attentions
:
return
all_attentions
,
all_encoder_layers
return
all_encoder_layers
return
all_encoder_layers
...
@@ -699,10 +719,11 @@ class BertModel(BertPreTrainedModel):
...
@@ -699,10 +719,11 @@ class BertModel(BertPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
```
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
output_attentions
=
False
):
super
(
BertModel
,
self
).
__init__
(
config
)
super
(
BertModel
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
embeddings
=
BertEmbeddings
(
config
)
self
.
embeddings
=
BertEmbeddings
(
config
)
self
.
encoder
=
BertEncoder
(
config
)
self
.
encoder
=
BertEncoder
(
config
,
output_attentions
=
output_attentions
)
self
.
pooler
=
BertPooler
(
config
)
self
.
pooler
=
BertPooler
(
config
)
self
.
apply
(
self
.
init_bert_weights
)
self
.
apply
(
self
.
init_bert_weights
)
...
@@ -731,10 +752,14 @@ class BertModel(BertPreTrainedModel):
...
@@ -731,10 +752,14 @@ class BertModel(BertPreTrainedModel):
encoded_layers
=
self
.
encoder
(
embedding_output
,
encoded_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
,
extended_attention_mask
,
output_all_encoded_layers
=
output_all_encoded_layers
)
output_all_encoded_layers
=
output_all_encoded_layers
)
if
self
.
output_attentions
:
all_attentions
,
encoded_layers
=
encoded_layers
sequence_output
=
encoded_layers
[
-
1
]
sequence_output
=
encoded_layers
[
-
1
]
pooled_output
=
self
.
pooler
(
sequence_output
)
pooled_output
=
self
.
pooler
(
sequence_output
)
if
not
output_all_encoded_layers
:
if
not
output_all_encoded_layers
:
encoded_layers
=
encoded_layers
[
-
1
]
encoded_layers
=
encoded_layers
[
-
1
]
if
self
.
output_attentions
:
return
all_attentions
,
encoded_layers
,
pooled_output
return
encoded_layers
,
pooled_output
return
encoded_layers
,
pooled_output
...
...
tests/modeling_gpt2_test.py
View file @
a3274ac4
...
@@ -133,11 +133,28 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -133,11 +133,28 @@ class GPT2ModelTest(unittest.TestCase):
}
}
return
outputs
return
outputs
def
create_gpt2_lm_head_with_output_attention
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
model
=
GPT2LMHeadModel
(
config
,
output_attentions
=
True
)
model
.
eval
()
loss
=
model
(
input_ids
,
position_ids
,
token_type_ids
,
lm_labels
)
attentions
,
lm_logits
,
presents
=
model
(
input_ids
,
position_ids
,
token_type_ids
)
outputs
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
,
"presents"
:
presents
,
"attentions"
:
attentions
,
}
return
outputs
def
check_gpt2_lm_head_output
(
self
,
result
):
def
check_gpt2_lm_head_output
(
self
,
result
):
total_voc
=
self
.
n_special
+
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"presents"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
def
check_gpt2_lm_head_loss_output
(
self
,
result
):
def
check_gpt2_lm_head_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
...
@@ -160,6 +177,23 @@ class GPT2ModelTest(unittest.TestCase):
...
@@ -160,6 +177,23 @@ class GPT2ModelTest(unittest.TestCase):
}
}
return
outputs
return
outputs
def
create_gpt2_double_heads_with_output_attention
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
model
=
GPT2DoubleHeadsModel
(
config
,
output_attentions
=
True
)
model
.
eval
()
loss
=
model
(
input_ids
,
mc_token_ids
,
lm_labels
=
lm_labels
,
mc_labels
=
mc_labels
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
)
attentions
,
lm_logits
,
mc_logits
,
presents
=
model
(
input_ids
,
mc_token_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
outputs
=
{
"loss"
:
loss
,
"lm_logits"
:
lm_logits
,
"mc_logits"
:
mc_logits
,
"presents"
:
presents
,
"attentions"
:
attentions
,
}
return
outputs
def
check_gpt2_double_heads_output
(
self
,
result
):
def
check_gpt2_double_heads_output
(
self
,
result
):
total_voc
=
self
.
n_special
+
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment