chenpangpang / transformers / Commits

Commit de5e5682
Authored Jun 03, 2019 by VictorSanh
Parent 275179a0

add output_attentions for BertModel

Showing 1 changed file with 27 additions and 7 deletions

pytorch_pretrained_bert/modeling.py  (+27, -7)
@@ -275,7 +275,7 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -291,6 +291,8 @@ class BertSelfAttention(nn.Module):
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.output_attentions = output_attentions

     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -322,7 +324,10 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        return context_layer
+        if self.output_attentions:
+            return attention_probs, context_layer
+        else:
+            return context_layer


 class BertSelfOutput(nn.Module):
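With output_attentions=True, BertSelfAttention.forward now returns an (attention_probs, context_layer) pair instead of the single context tensor. A minimal sketch of the new return contract, assuming a default BertConfig and placeholder zero inputs shaped the way the rest of modeling.py shapes them:

```python
# Sketch of the new BertSelfAttention return contract; config values and
# input shapes below are assumptions based on the defaults in modeling.py.
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertSelfAttention

config = BertConfig(vocab_size_or_config_json_file=30522)
attn = BertSelfAttention(config, output_attentions=True)

hidden_states = torch.zeros(2, 8, config.hidden_size)   # (batch, seq_len, hidden_size)
extended_mask = torch.zeros(2, 1, 1, 8)                  # additive mask, 0 = keep

attention_probs, context_layer = attn(hidden_states, extended_mask)
# attention_probs: (batch, num_attention_heads, seq_len, seq_len)
# context_layer:   (batch, seq_len, hidden_size), same as before the change
```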
@@ -381,33 +386,43 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
+        self.output_attentions = output_attentions

     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attention_output, layer_output
         return layer_output


 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.output_attentions = output_attentions

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
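Each BertLayer built with the flag now hands back a tuple, and BertEncoder unpacks it per layer to accumulate all_attentions alongside the hidden states. A short sketch of consuming the new encoder output, assuming a small BertConfig; note that BertAttention itself is not touched in this diff, so what each layer contributes to all_attentions depends on how the flag is propagated inside the attention block:

```python
# Sketch of the new BertEncoder output; shapes and kwargs are assumptions
# taken from the surrounding code in modeling.py.
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertEncoder

config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=2)
encoder = BertEncoder(config, output_attentions=True)

hidden_states = torch.zeros(2, 8, config.hidden_size)
extended_mask = torch.zeros(2, 1, 1, 8)

all_attentions, all_encoder_layers = encoder(
    hidden_states, extended_mask, output_all_encoded_layers=True)

# Both lists hold one entry per transformer layer.
assert len(all_attentions) == config.num_hidden_layers
assert len(all_encoder_layers) == config.num_hidden_layers
```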
@@ -699,12 +714,13 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
+        self.output_attentions = output_attentions

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
         if attention_mask is None:
@@ -731,10 +747,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
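Putting the pieces together, constructing BertModel with output_attentions=True makes forward return the collected attention tensors as an extra first value. A hedged usage sketch with a randomly initialised model and the toy ids from the class docstring example (from_pretrained should also forward extra keyword arguments such as this flag to the constructor):

```python
# Usage sketch for the new BertModel flag (randomly initialised weights;
# toy ids borrowed from the class docstring example).
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=30522)
model = BertModel(config, output_attentions=True)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

with torch.no_grad():
    all_attentions, encoded_layers, pooled_output = model(
        input_ids, token_type_ids, input_mask, output_all_encoded_layers=False)

# all_attentions: list with one tensor per hidden layer
# encoded_layers: (batch, seq_len, hidden_size), final layer only because
#                 output_all_encoded_layers=False
# pooled_output:  (batch, hidden_size)
```

When output_attentions is left at its default of False, every return value is unchanged, so existing callers keep working.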