chenpangpang / transformers · Commits

Commit 82649658, authored Jun 03, 2019 by VictorSanh

Revert "add output_attentions for BertModel"

This reverts commit de5e5682.

parent de5e5682
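In plain terms, this revert removes the optional attention-probability outputs introduced in de5e5682: the constructors go back to taking only `config`, and `BertModel.forward` again returns just `(encoded_layers, pooled_output)`. A rough usage sketch of what changes for callers, assuming the usual `from_pretrained` entry point; the dummy token ids below are illustrative and not part of the commit:

```python
import torch
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

input_ids = torch.tensor([[101, 7592, 2088, 102]])   # roughly "[CLS] hello world [SEP]"
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)

# After this revert, forward returns only the hidden states and the pooled output:
encoded_layers, pooled_output = model(input_ids, token_type_ids, input_mask)

# The reverted commit de5e5682 instead allowed constructing the model as
#     BertModel(config, output_attentions=True)
# in which case forward returned (all_attentions, encoded_layers, pooled_output).
```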
Changes: 1 changed file with 7 additions and 27 deletions

pytorch_pretrained_bert/modeling.py  (+7 −27)
@@ -275,7 +275,7 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -291,8 +291,6 @@ class BertSelfAttention(nn.Module):
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.output_attentions = output_attentions
-
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -324,10 +322,7 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        if self.output_attentions:
-            return attention_probs, context_layer
-        else:
-            return context_layer
+        return context_layer


 class BertSelfOutput(nn.Module):
@@ -386,43 +381,33 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertLayer, self).__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
-        self.output_attentions = output_attentions

     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        if self.output_attentions:
-            return attention_output, layer_output
         return layer_output


 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config, output_attentions=output_attentions)
+        layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
-        self.output_attentions = output_attentions

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
-        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
-            if self.output_attentions:
-                attentions, hidden_states = hidden_states
-                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if self.output_attentions:
-            return all_attentions, all_encoder_layers
         return all_encoder_layers
@@ -714,13 +699,12 @@ class BertModel(BertPreTrainedModel):
         all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, output_attentions=output_attentions)
+        self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
-        self.output_attentions = output_attentions

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
         if attention_mask is None:
@@ -747,14 +731,10 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
-        if self.output_attentions:
-            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        if self.output_attentions:
-            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
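For context on why the encoder-loop changes above look the way they do: the reverted commit made each layer return `(attentions, hidden_states)` whenever `output_attentions` was set, and the encoder unpacked that tuple on every iteration. Below is a self-contained toy sketch of that pattern; `ToyLayer` and `ToyEncoder` are hypothetical stand-ins, not the classes in this file.

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Hypothetical stand-in for BertLayer: optionally also returns its attention map."""
    def __init__(self, hidden_size=4, output_attentions=False):
        super(ToyLayer, self).__init__()
        self.output_attentions = output_attentions
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        # A deliberately simple "attention": softmax over pairwise dot products.
        attentions = torch.softmax(
            hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        hidden_states = self.linear(attentions @ hidden_states)
        if self.output_attentions:
            # Tuple return: callers must unpack (attentions, hidden_states).
            return attentions, hidden_states
        return hidden_states

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BertEncoder's loop over layers."""
    def __init__(self, num_layers=2, output_attentions=False):
        super(ToyEncoder, self).__init__()
        self.output_attentions = output_attentions
        self.layer = nn.ModuleList(
            [ToyLayer(output_attentions=output_attentions) for _ in range(num_layers)])

    def forward(self, hidden_states):
        all_attentions = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
            if self.output_attentions:
                # The unpacking pattern removed by this revert.
                attentions, hidden_states = hidden_states
                all_attentions.append(attentions)
        if self.output_attentions:
            return all_attentions, hidden_states
        return hidden_states

x = torch.randn(1, 3, 4)
all_attentions, out = ToyEncoder(output_attentions=True)(x)
print(len(all_attentions), all_attentions[0].shape, out.shape)
# 2 torch.Size([1, 3, 3]) torch.Size([1, 3, 4])
```

The revert drops this tuple convention, so every layer once again returns a single tensor and the encoder returns only the list of encoded layers.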