chenpangpang / transformers · Commits

Commit fa218e64 authored Oct 10, 2019 by Rémi Louf

fix syntax errors

parent 3e1cd824
Showing 1 changed file with 8 additions and 8 deletions

transformers/modeling_bert.py  +8 -8
@@ -201,7 +201,7 @@ class BertSelfAttention(nn.Module):
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
-        if encoder_hidden_states:  # if encoder-decoder attention
+        if encoder_hidden_states is not None:  # if encoder-decoder attention
             mixed_query_layer = self.query(encoder_hidden_states)
         else:
             mixed_query_layer = self.query(hidden_states)
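This hunk replaces a truthiness check with an explicit `is not None` test. `encoder_hidden_states` is either `None` or a `torch.Tensor`, and truth-testing a tensor with more than one element raises a RuntimeError in PyTorch, so the old guard would crash as soon as encoder states were actually passed. A minimal standalone sketch of the difference (the tensor shape is illustrative, not taken from the commit):

    import torch

    encoder_hidden_states = torch.randn(2, 5, 768)  # illustrative (batch, seq_len, hidden) tensor

    # Truth-testing a multi-element tensor is ambiguous and raises a RuntimeError:
    try:
        if encoder_hidden_states:
            pass
    except RuntimeError as err:
        print(err)  # "Boolean value of Tensor with more than one element is ambiguous..."

    # The explicit check used by the fix handles both None and tensor inputs:
    if encoder_hidden_states is not None:
        print("take the encoder-decoder attention branch")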
@@ -331,11 +331,12 @@ class BertLayer(nn.Module):
         attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
         attention_output = attention_outputs[0]
 
-        if encoder_hidden_state:
+        if encoder_hidden_state is not None:
             try:
                 attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
             except AttributeError as ae:
-                raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
+                print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
+                raise
             attention_output = attention_outputs[0]
 
         intermediate_output = self.intermediate(attention_output)
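Besides the same `is not None` guard, this hunk fixes the error handling: `ae` is the caught AttributeError instance, so the old `raise ae(...)` tried to call it and failed with "TypeError: 'AttributeError' object is not callable", hiding the real problem. The fixed version prints a hint and then re-raises the original exception with a bare `raise`. A short sketch of the pattern, using a hypothetical `Layer` class that lacks a `crossattention` attribute:

    class Layer:
        """Hypothetical stand-in for a BertLayer built without cross-attention."""

    layer = Layer()
    try:
        layer.crossattention  # raises AttributeError: no such attribute
    except AttributeError as ae:
        # Old code: raise ae("...") -> TypeError, the caught instance is not callable.
        # Fixed code: print a hint, then re-raise the original exception unchanged.
        print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
        raise  # the script ends with the original AttributeError and its traceback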
@@ -382,7 +383,7 @@ class BertDecoder(nn.Module):
         config.is_decoder = True
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
         all_hidden_states = ()
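The rename from `self.layers` to `self.layer` matches the attribute name used elsewhere in the file: the pruning code in the next hunk indexes `self.decoder.layer[layer]`, and BertEncoder likewise stores its blocks under `self.layer`. A tiny sketch of why the names must line up (`TinyDecoder` is purely illustrative):

    import torch.nn as nn

    class TinyDecoder(nn.Module):
        """Purely illustrative module mirroring the ModuleList attribute name."""
        def __init__(self, num_layers=2):
            super().__init__()
            self.layer = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_layers)])

    decoder = TinyDecoder()
    print(decoder.layer[0])   # same access pattern as self.decoder.layer[layer]
    # decoder.layers[0]       # would raise AttributeError after the rename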
@@ -738,7 +739,7 @@ class BertDecoderModel(BertPreTrainedModel):
             self.decoder.layer[layer].attention.prune_heads(heads)
             self.decoder.layer[layer].crossattention.prune_heads(heads)
 
-    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -782,7 +783,7 @@ class BertDecoderModel(BertPreTrainedModel):
         sequence_output = decoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
 
-        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]  # add hidden_states and attentions if they are here
 
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
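The extras appended to the returned tuple now come from `decoder_outputs`, the tuple that produced `sequence_output`, rather than from `encoder_outputs`, which is a plain argument of this forward pass. A toy illustration of the output-tuple convention, with string stand-ins instead of tensors:

    # String stand-ins; in the model these are tensors or tuples of tensors.
    sequence_output, pooled_output = "sequence_output", "pooled_output"
    decoder_outputs = ("sequence_output", "all_hidden_states", "all_attentions")

    outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]
    print(outputs)
    # ('sequence_output', 'pooled_output', 'all_hidden_states', 'all_attentions')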
@@ -1387,8 +1388,7 @@ class Bert2Rnd(BertPreTrainedModel):
                                        head_mask=head_mask)
         encoder_output = encoder_outputs[0]
 
-        decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=self.config.initializer_range)
-        decoder_outputs = self.decoder(decoder_input,
+        decoder_outputs = self.decoder(input_ids,
                                        encoder_output,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
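The removed lines fed the decoder a tensor of random values in place of token ids. `torch.empty_like(input_ids)` keeps the integer dtype of `input_ids`, and `normal_` is only defined for floating-point tensors, so that line errors out before the decoder is even called; passing `input_ids` straight through is the working behaviour. A short sketch of the failure mode (the id values are made up):

    import torch

    input_ids = torch.tensor([[101, 2023, 102]])  # made-up token ids (int64)

    try:
        decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=0.02)
    except RuntimeError as err:
        # normal_ is not implemented for integer tensors, so this branch is taken.
        print(err)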