Unverified commit 22933e66 authored by Sam Shleifer, committed by GitHub

[bart] rename self-attention -> attention (#6708)

parent 0f58903b
@@ -225,11 +225,7 @@ class EncoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = SelfAttention(
-            self.embed_dim,
-            config.encoder_attention_heads,
-            dropout=config.attention_dropout,
-        )
+        self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
         self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -377,7 +373,8 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = SelfAttention(
+        self.self_attn = Attention(
             embed_dim=self.embed_dim,
             num_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
@@ -388,7 +385,7 @@ class DecoderLayer(nn.Module):
         self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
-        self.encoder_attn = SelfAttention(
+        self.encoder_attn = Attention(
             self.embed_dim,
             config.decoder_attention_heads,
             dropout=config.attention_dropout,
@@ -586,7 +583,7 @@ class BartDecoder(nn.Module):
             if use_cache:
                 next_decoder_cache.append(layer_past.copy())
-            if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
+            if self.layer_norm and (idx == len(self.layers) - 1):  # if config.add_final_layer_norm (mBART)
                 x = self.layer_norm(x)
             if output_attentions:
                 all_self_attns += (layer_self_attn,)
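Note on the comment change above: `self.layer_norm` in `BartDecoder` is only a real module when the config requests a final layer norm, which is how mBART checkpoints are configured. A minimal sketch of that relationship, assuming the public `BartConfig` flag named in the new comment (the values here are illustrative and not part of this diff):

```python
from transformers import BartConfig

# mBART-style decoder: add_final_layer_norm enables the final LayerNorm that
# the renamed comment refers to; vanilla BART leaves it disabled by default.
mbart_like = BartConfig(add_final_layer_norm=True)
bart_like = BartConfig()

print(mbart_like.add_final_layer_norm, bart_like.add_final_layer_norm)  # True False
```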
@@ -616,7 +613,7 @@ def _reorder_buffer(attn_cache, new_order):
     return attn_cache


-class SelfAttention(nn.Module):
+class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
...
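Since this commit only renames the class (`SelfAttention` -> `Attention`) without changing its constructor, external code that referenced the old name just swaps the identifier. A hypothetical sketch, assuming the flat `transformers.modeling_bart` module path of that era; the argument values are made up, but the call shapes mirror the encoder and decoder hunks above:

```python
from transformers.modeling_bart import Attention  # formerly SelfAttention

# Positional style, as in EncoderLayer.__init__ after this change.
enc_self_attn = Attention(1024, 16, dropout=0.1)

# Keyword style, as in DecoderLayer.__init__.
dec_self_attn = Attention(embed_dim=1024, num_heads=16, dropout=0.1)
```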