chenpangpang / transformers / Commits / 22933e66

Unverified commit 22933e66, authored Aug 29, 2020 by Sam Shleifer, committed by GitHub on Aug 29, 2020.
[bart] rename self-attention -> attention (#6708)
Parent: 0f58903b

Showing 1 changed file with 6 additions and 9 deletions:
src/transformers/modeling_bart.py (+6, -9)
@@ -225,11 +225,7 @@ class EncoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = SelfAttention(
-            self.embed_dim,
-            config.encoder_attention_heads,
-            dropout=config.attention_dropout,
-        )
+        self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
         self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout

@@ -377,7 +373,8 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = SelfAttention(
+        self.self_attn = Attention(
             embed_dim=self.embed_dim,
             num_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,

@@ -388,7 +385,7 @@ class DecoderLayer(nn.Module):
         self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
-        self.encoder_attn = SelfAttention(
+        self.encoder_attn = Attention(
             self.embed_dim,
             config.decoder_attention_heads,
             dropout=config.attention_dropout,

@@ -586,7 +583,7 @@ class BartDecoder(nn.Module):
             if use_cache:
                 next_decoder_cache.append(layer_past.copy())
-            if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
+            if self.layer_norm and (idx == len(self.layers) - 1):  # if config.add_final_layer_norm (mBART)
                 x = self.layer_norm(x)
             if output_attentions:
                 all_self_attns += (layer_self_attn,)
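Besides the renames, the hunk above also rewrites a comment to name the flag behind that branch: the decoder only has a final layer_norm when config.add_final_layer_norm is set (mBART), and it runs on the output of the last layer only. A minimal sketch of that gating, using a hypothetical TinyDecoder stand-in (only the add_final_layer_norm flag and the last-layer check come from the diff; everything else is illustrative, not the BartDecoder implementation):

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    # Hypothetical stand-in, not BartDecoder: the final LayerNorm exists only
    # when add_final_layer_norm is requested, and runs on the last layer's output.
    def __init__(self, d_model=16, num_layers=2, add_final_layer_norm=True):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
        self.layer_norm = nn.LayerNorm(d_model) if add_final_layer_norm else None

    def forward(self, x):
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if self.layer_norm and (idx == len(self.layers) - 1):  # if add_final_layer_norm (mBART-style)
                x = self.layer_norm(x)
        return x


x = torch.randn(2, 5, 16)
print(TinyDecoder(add_final_layer_norm=True)(x).shape)   # torch.Size([2, 5, 16])
print(TinyDecoder(add_final_layer_norm=False)(x).shape)  # same shape, final LayerNorm skipped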
@@ -616,7 +613,7 @@ def _reorder_buffer(attn_cache, new_order):
     return attn_cache

-class SelfAttention(nn.Module):
+class Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
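After this change, DecoderLayer's self_attn and encoder_attn are both instances of the same renamed Attention class, which is why the old name SelfAttention was misleading: the module also handles encoder-decoder (cross) attention. A minimal, hypothetical sketch of that dual use, wrapping torch.nn.MultiheadAttention purely for illustration (the class below and its forward signature are assumptions, not the modeling_bart.py implementation):

import torch
import torch.nn as nn


class Attention(nn.Module):
    # Illustrative wrapper only; modeling_bart.py defines its own Attention module.
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

    def forward(self, query, key_value=None):
        # No separate key/value source means self-attention; otherwise cross-attention.
        key_value = query if key_value is None else key_value
        return self.mha(query, key_value, key_value)


embed_dim, num_heads = 16, 4
self_attn = Attention(embed_dim, num_heads)     # plays the role of DecoderLayer.self_attn
encoder_attn = Attention(embed_dim, num_heads)  # plays the role of DecoderLayer.encoder_attn

x = torch.randn(5, 2, embed_dim)    # decoder states: (tgt_len, batch, embed_dim)
enc = torch.randn(7, 2, embed_dim)  # encoder states: (src_len, batch, embed_dim)

y, _ = self_attn(x)                    # attends over the decoder states themselves
z, _ = encoder_attn(x, key_value=enc)  # attends over the encoder outputs
print(y.shape, z.shape)                # both torch.Size([5, 2, 16])

Only the key/value source differs between the two calls; the module itself is the same, hence the generic name.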