chenpangpang / transformers · Commit a051d892

Commit a051d892 (unverified), authored Jan 12, 2021 by Patrick von Platen, committed by GitHub on Jan 12, 2021.

[ProphetNet] Fix naming and wrong config (#9514)

* fix naming issues
* better names

Parent: 7f286132
Changes: 1 changed file with 14 additions and 14 deletions.

src/transformers/models/prophetnet/modeling_prophetnet.py (+14, -14)
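The commit does two things at once: it fixes the misspelled "Prohpet" class names (and drops the redundant "ProphetNetSelf" / "ProphetNetNgramProphetNet" naming), and it makes the attention code read the head counts from the encoder- and decoder-specific config fields instead of the generic num_attention_heads. A minimal usage sketch, not part of the diff, assuming only the ProphetNetConfig keyword arguments that the changed lines below reference (num_encoder_attention_heads, num_decoder_attention_heads, encoder_ffn_dim, decoder_ffn_dim):

    # Hypothetical config where encoder and decoder deliberately use different head counts.
    # Before this fix, the decoder n-gram self-attention and the mask broadcasting read
    # config.num_attention_heads, which cannot distinguish the two.
    from transformers import ProphetNetConfig, ProphetNetModel

    config = ProphetNetConfig(
        vocab_size=30522,
        hidden_size=512,
        num_encoder_layers=2,
        num_decoder_layers=2,
        num_encoder_attention_heads=4,    # encoder-side heads
        num_decoder_attention_heads=8,    # decoder-side heads
        encoder_ffn_dim=1024,
        decoder_ffn_dim=1024,
    )
    model = ProphetNetModel(config)       # decoder attention now sized by num_decoder_attention_heads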
@@ -559,7 +559,7 @@ class ProphetNetPreTrainedModel(PreTrainedModel):
         return shifted_input_ids
 
 
-class ProhpetNetPositionalEmbeddings(nn.Embedding):
+class ProphetNetPositionalEmbeddings(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
     based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
@@ -598,7 +598,7 @@ class ProhpetNetPositionalEmbeddings(nn.Embedding):
         return super().forward(position_ids)
 
 
-class ProphetNetSelfAttention(nn.Module):
+class ProphetNetAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -726,7 +726,7 @@ class ProphetNetSelfAttention(nn.Module):
         return attn_output, attn_weights_reshaped
 
 
-class ProhpetNetFeedForward(nn.Module):
+class ProphetNetFeedForward(nn.Module):
     """
     This is the residual two feed-forward layer block based on the original Transformer implementation.
     """
@@ -749,14 +749,14 @@ class ProhpetNetFeedForward(nn.Module):
         return hidden_states
 
 
-class ProphetNetNgramProphetNetSelfAttention(nn.Module):
+class ProphetNetNgramSelfAttention(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
 
         self.num_buckets = config.num_buckets
         self.relative_max_distance = config.relative_max_distance
-        self.num_attn_heads = config.num_attention_heads
+        self.num_attn_heads = config.num_decoder_attention_heads
         self.dropout = config.dropout
         self.attention_dropout = config.attention_dropout
         self.head_dim = config.hidden_size // self.num_attn_heads
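Because head_dim is derived as hidden_size // num_attn_heads in the last line of this hunk, reading the wrong head count does not only affect mask shapes, it also silently changes the per-head dimension whenever encoder and decoder head counts differ. A small arithmetic sketch with hypothetical numbers, not taken from the diff:

    hidden_size = 1024
    num_decoder_attention_heads = 16   # what the n-gram self-attention should use
    num_attention_heads = 8            # generic value the old line read

    head_dim_old = hidden_size // num_attention_heads          # 128
    head_dim_new = hidden_size // num_decoder_attention_heads  # 64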
@@ -1046,11 +1046,11 @@ class ProphetNetEncoderLayer(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
+        self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
-        self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
+        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
 
     def forward(self, hidden_states, attention_mask):
@@ -1075,16 +1075,16 @@ class ProphetNetDecoderLayer(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
+        self.self_attn = ProphetNetNgramSelfAttention(config)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
         if config.add_cross_attention:
-            self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
+            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
             self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 3rd residual block
-        self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
+        self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
 
     def forward(
@@ -1156,7 +1156,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
             if word_embeddings is not None
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
-        self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
         self.embeddings_layer_norm = LayerNorm(config.hidden_size)
 
         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
@@ -1212,7 +1212,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         # prepare attention mask
         if attention_mask is not None:
             extended_attention_mask = (
-                1.0 - attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
+                1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
             ) * -10000.0
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
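For reference, the masking pattern in this hunk turns a 0/1 padding mask into an additive bias repeated once per attention head, so the corrected line must repeat by the encoder's own head count. A standalone sketch with assumed shapes, not part of the commit:

    import torch

    num_encoder_attention_heads = 4
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])  # (batch=2, seq_len=5); 1 = keep, 0 = pad

    # (batch, 1, seq_len) repeated along dim 0 -> (batch * num_heads, 1, seq_len)
    extended_attention_mask = (
        1.0 - attention_mask[:, None, :].repeat(num_encoder_attention_heads, 1, 1)
    ) * -10000.0
    print(extended_attention_mask.shape)  # torch.Size([8, 1, 5]); padded positions become -10000.0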
@@ -1273,7 +1273,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             if word_embeddings is not None
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
-        self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
 
         self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
@@ -1397,7 +1397,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         # prepare encoder attention mask
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
-                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
+                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
             ) * -10000.0
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else: