Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a051d892
Unverified
Commit
a051d892
authored
Jan 12, 2021
by
Patrick von Platen
Committed by
GitHub
Jan 12, 2021
Browse files
[ProphetNet] Fix naming and wrong config (#9514)
* fix naming issues * better names
parent
7f286132
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
14 deletions
+14
-14
src/transformers/models/prophetnet/modeling_prophetnet.py
src/transformers/models/prophetnet/modeling_prophetnet.py
+14
-14
No files found.
src/transformers/models/prophetnet/modeling_prophetnet.py
View file @
a051d892
...
...
@@ -559,7 +559,7 @@ class ProphetNetPreTrainedModel(PreTrainedModel):
return
shifted_input_ids
class
Pro
h
petNetPositionalEmbeddings
(
nn
.
Embedding
):
class
Prop
h
etNetPositionalEmbeddings
(
nn
.
Embedding
):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
...
...
@@ -598,7 +598,7 @@ class ProhpetNetPositionalEmbeddings(nn.Embedding):
return
super
().
forward
(
position_ids
)
class
ProphetNet
Self
Attention
(
nn
.
Module
):
class
ProphetNetAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
...
...
@@ -726,7 +726,7 @@ class ProphetNetSelfAttention(nn.Module):
return
attn_output
,
attn_weights_reshaped
class
Pro
h
petNetFeedForward
(
nn
.
Module
):
class
Prop
h
etNetFeedForward
(
nn
.
Module
):
"""
This is the residual two feed-forward layer block based on the original Transformer implementation.
"""
...
...
@@ -749,14 +749,14 @@ class ProhpetNetFeedForward(nn.Module):
return
hidden_states
class
ProphetNetNgram
ProphetNet
SelfAttention
(
nn
.
Module
):
class
ProphetNetNgramSelfAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
ProphetNetConfig
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_buckets
=
config
.
num_buckets
self
.
relative_max_distance
=
config
.
relative_max_distance
self
.
num_attn_heads
=
config
.
num_attention_heads
self
.
num_attn_heads
=
config
.
num_
decoder_
attention_heads
self
.
dropout
=
config
.
dropout
self
.
attention_dropout
=
config
.
attention_dropout
self
.
head_dim
=
config
.
hidden_size
//
self
.
num_attn_heads
...
...
@@ -1046,11 +1046,11 @@ class ProphetNetEncoderLayer(nn.Module):
def
__init__
(
self
,
config
:
ProphetNetConfig
):
super
().
__init__
()
# 1st residual block
self
.
self_attn
=
ProphetNet
Self
Attention
(
config
,
config
.
num_encoder_attention_heads
)
self
.
self_attn
=
ProphetNetAttention
(
config
,
config
.
num_encoder_attention_heads
)
self
.
self_attn_layer_norm
=
LayerNorm
(
config
.
hidden_size
)
# 2nd residual block
self
.
feed_forward
=
Pro
h
petNetFeedForward
(
config
,
config
.
encoder_ffn_dim
)
self
.
feed_forward
=
Prop
h
etNetFeedForward
(
config
,
config
.
encoder_ffn_dim
)
self
.
feed_forward_layer_norm
=
LayerNorm
(
config
.
hidden_size
)
def
forward
(
self
,
hidden_states
,
attention_mask
):
...
...
@@ -1075,16 +1075,16 @@ class ProphetNetDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
ProphetNetConfig
):
super
().
__init__
()
# 1st residual block
self
.
self_attn
=
ProphetNetNgram
ProphetNet
SelfAttention
(
config
)
self
.
self_attn
=
ProphetNetNgramSelfAttention
(
config
)
self
.
self_attn_layer_norm
=
LayerNorm
(
config
.
hidden_size
)
# 2nd residual block
if
config
.
add_cross_attention
:
self
.
cross_attn
=
ProphetNet
Self
Attention
(
config
,
config
.
num_decoder_attention_heads
)
self
.
cross_attn
=
ProphetNetAttention
(
config
,
config
.
num_decoder_attention_heads
)
self
.
cross_attn_layer_norm
=
LayerNorm
(
config
.
hidden_size
)
# 3rd residual block
self
.
feed_forward
=
Pro
h
petNetFeedForward
(
config
,
config
.
decoder_ffn_dim
)
self
.
feed_forward
=
Prop
h
etNetFeedForward
(
config
,
config
.
decoder_ffn_dim
)
self
.
feed_forward_layer_norm
=
LayerNorm
(
config
.
hidden_size
)
def
forward
(
...
...
@@ -1156,7 +1156,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
if
word_embeddings
is
not
None
else
nn
.
Embedding
(
config
.
vocab_size
,
config
.
hidden_size
,
padding_idx
=
config
.
pad_token_id
)
)
self
.
position_embeddings
=
Pro
h
petNetPositionalEmbeddings
(
config
)
self
.
position_embeddings
=
Prop
h
etNetPositionalEmbeddings
(
config
)
self
.
embeddings_layer_norm
=
LayerNorm
(
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
ProphetNetEncoderLayer
(
config
)
for
_
in
range
(
config
.
num_encoder_layers
)])
...
...
@@ -1212,7 +1212,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
# prepare attention mask
if
attention_mask
is
not
None
:
extended_attention_mask
=
(
1.0
-
attention_mask
[:,
None
,
:].
repeat
(
self
.
config
.
num_attention_heads
,
1
,
1
)
1.0
-
attention_mask
[:,
None
,
:].
repeat
(
self
.
config
.
num_
encoder_
attention_heads
,
1
,
1
)
)
*
-
10000.0
extended_attention_mask
=
extended_attention_mask
.
to
(
inputs_embeds
.
dtype
)
else
:
...
...
@@ -1273,7 +1273,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
if
word_embeddings
is
not
None
else
nn
.
Embedding
(
config
.
vocab_size
,
config
.
hidden_size
,
padding_idx
=
config
.
pad_token_id
)
)
self
.
position_embeddings
=
Pro
h
petNetPositionalEmbeddings
(
config
)
self
.
position_embeddings
=
Prop
h
etNetPositionalEmbeddings
(
config
)
self
.
ngram_embeddings
=
nn
.
Embedding
(
self
.
ngram
,
config
.
hidden_size
,
None
)
self
.
layers
=
nn
.
ModuleList
([
ProphetNetDecoderLayer
(
config
)
for
_
in
range
(
config
.
num_decoder_layers
)])
...
...
@@ -1397,7 +1397,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
# prepare encoder attention mask
if
encoder_attention_mask
is
not
None
:
extended_encoder_attention_mask
=
(
1.0
-
encoder_attention_mask
[:,
None
,
:].
repeat
(
self
.
config
.
num_attention_heads
,
1
,
1
)
1.0
-
encoder_attention_mask
[:,
None
,
:].
repeat
(
self
.
config
.
num_
decoder_
attention_heads
,
1
,
1
)
)
*
-
10000.0
extended_encoder_attention_mask
=
extended_encoder_attention_mask
.
to
(
inputs_embeds
.
dtype
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment