"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6596e3d56626c921b3920e313866b7412633b91a"
Unverified commit a051d892, authored by Patrick von Platen, committed by GitHub

[ProphetNet] Fix naming and wrong config (#9514)

* fix naming issues

* better names
parent 7f286132
@@ -559,7 +559,7 @@ class ProphetNetPreTrainedModel(PreTrainedModel):
         return shifted_input_ids


-class ProhpetNetPositionalEmbeddings(nn.Embedding):
+class ProphetNetPositionalEmbeddings(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
     based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
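The docstring in the hunk above describes a learned positional embedding that keeps padding_idx reserved for padding tokens by offsetting the real position ids past it. A minimal, self-contained sketch of that offsetting idea follows; OffsetPositionalEmbedding, max_positions and the shapes used are illustrative assumptions, not ProphetNet's actual API.

import torch
import torch.nn as nn


class OffsetPositionalEmbedding(nn.Embedding):
    """Toy learned positional embedding that keeps padding_idx for padding tokens."""

    def __init__(self, max_positions: int, hidden_size: int, padding_idx: int = 0):
        # Reserve extra slots so real positions can start after padding_idx.
        super().__init__(max_positions + padding_idx + 1, hidden_size, padding_idx=padding_idx)

    def forward(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # Count only non-padding tokens, then shift the ids past padding_idx;
        # padded slots collapse onto padding_idx and receive the all-zero vector.
        position_ids = torch.cumsum(attention_mask, dim=1) * attention_mask + self.padding_idx
        return super().forward(position_ids.long())


# Usage: a padded batch of 2 sequences, hidden size 8.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
embeddings = OffsetPositionalEmbedding(max_positions=512, hidden_size=8, padding_idx=0)
print(embeddings(mask).shape)  # torch.Size([2, 4, 8])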
@@ -598,7 +598,7 @@ class ProhpetNetPositionalEmbeddings(nn.Embedding):
         return super().forward(position_ids)


-class ProphetNetSelfAttention(nn.Module):
+class ProphetNetAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
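The class renamed in this hunk is, per its docstring, standard multi-head attention from 'Attention Is All You Need'. For orientation only, here is a compact sketch of the scaled dot-product step at its core; the function name and the (batch * heads, seq_len, head_dim) layout are assumptions for illustration, not ProphetNetAttention's signature.

import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(query, key, value, additive_mask=None):
    # query/key/value: (batch * heads, seq_len, head_dim).
    # additive_mask: broadcastable to (batch * heads, tgt_len, src_len), holding
    # 0.0 at visible positions and large negative values at masked ones.
    scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(query.size(-1))
    if additive_mask is not None:
        scores = scores + additive_mask
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value), weights


# Usage with random tensors: 2 sequences x 4 heads, length 5, head_dim 16.
q = k = v = torch.randn(8, 5, 16)
output, attn_weights = scaled_dot_product_attention(q, k, v)
print(output.shape, attn_weights.shape)  # torch.Size([8, 5, 16]) torch.Size([8, 5, 5])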
@@ -726,7 +726,7 @@ class ProphetNetSelfAttention(nn.Module):
         return attn_output, attn_weights_reshaped


-class ProhpetNetFeedForward(nn.Module):
+class ProphetNetFeedForward(nn.Module):
     """
     This is the residual two feed-forward layer block based on the original Transformer implementation.
     """
@@ -749,14 +749,14 @@ class ProhpetNetFeedForward(nn.Module):
         return hidden_states


-class ProphetNetNgramProphetNetSelfAttention(nn.Module):
+class ProphetNetNgramSelfAttention(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         self.hidden_size = config.hidden_size

         self.num_buckets = config.num_buckets
         self.relative_max_distance = config.relative_max_distance
-        self.num_attn_heads = config.num_attention_heads
+        self.num_attn_heads = config.num_decoder_attention_heads
         self.dropout = config.dropout
         self.attention_dropout = config.attention_dropout
         self.head_dim = config.hidden_size // self.num_attn_heads
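The second change in this hunk is the config fix named in the commit title: the decoder-side n-gram self-attention previously read config.num_attention_heads where it should read config.num_decoder_attention_heads. The difference only shows up when encoder and decoder use different head counts, as in the small sketch below; ToyProphetNetConfig and decoder_head_dim are hypothetical stand-ins, not transformers classes.

from dataclasses import dataclass


@dataclass
class ToyProphetNetConfig:
    # Illustrative subset of the fields referenced in the diff; ProphetNet keeps
    # separate head counts for the encoder and the decoder.
    hidden_size: int = 1024
    num_encoder_attention_heads: int = 16
    num_decoder_attention_heads: int = 8


def decoder_head_dim(config: ToyProphetNetConfig) -> int:
    num_heads = config.num_decoder_attention_heads
    head_dim = config.hidden_size // num_heads
    # Heads must evenly split the hidden size, otherwise the per-head projections break.
    assert head_dim * num_heads == config.hidden_size, "hidden_size must be divisible by the head count"
    return head_dim


config = ToyProphetNetConfig()
print(decoder_head_dim(config))  # 128 with 8 decoder heads; reading the 16 encoder heads would give 64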
@@ -1046,11 +1046,11 @@ class ProphetNetEncoderLayer(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
+        self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)

         # 2nd residual block
-        self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
+        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

     def forward(self, hidden_states, attention_mask):
@@ -1075,16 +1075,16 @@ class ProphetNetDecoderLayer(nn.Module):
     def __init__(self, config: ProphetNetConfig):
         super().__init__()
         # 1st residual block
-        self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
+        self.self_attn = ProphetNetNgramSelfAttention(config)
         self.self_attn_layer_norm = LayerNorm(config.hidden_size)

         # 2nd residual block
         if config.add_cross_attention:
-            self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
+            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
             self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

         # 3rd residual block
-        self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
+        self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

     def forward(
@@ -1156,7 +1156,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
             if word_embeddings is not None
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
-        self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
         self.embeddings_layer_norm = LayerNorm(config.hidden_size)

         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
@@ -1212,7 +1212,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         # prepare attention mask
         if attention_mask is not None:
             extended_attention_mask = (
-                1.0 - attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
+                1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
             ) * -10000.0
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
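This hunk picks the encoder-specific head count for the encoder's padding mask. The pattern itself is the usual additive mask: invert the 0/1 attention mask, scale it to a large negative value, and tile it once per head so it can be added to the attention scores before the softmax. A self-contained sketch of that pattern follows; build_extended_attention_mask is a hypothetical helper, and the real code additionally casts the result to inputs_embeds.dtype.

import torch


def build_extended_attention_mask(attention_mask: torch.Tensor, num_attention_heads: int) -> torch.Tensor:
    # attention_mask: (batch, src_len) with 1 at real tokens and 0 at padding.
    # Returns (batch * num_attention_heads, 1, src_len) with zeros at attendable
    # positions and -10000.0 at padding, ready to be added to attention scores.
    inverted = 1.0 - attention_mask[:, None, :].float()
    return (inverted * -10000.0).repeat(num_attention_heads, 1, 1)


# Usage: batch of 2, source length 4, 3 heads (sizes chosen only for the example).
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
extended = build_extended_attention_mask(mask, num_attention_heads=3)
print(extended.shape)  # torch.Size([6, 1, 4])
print(extended[0, 0])  # zeros at the three real tokens, -10000.0 at the padded one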
@@ -1273,7 +1273,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             if word_embeddings is not None
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
-        self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)

         self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
@@ -1397,7 +1397,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         # prepare encoder attention mask
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
-                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
+                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
             ) * -10000.0
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else: