Unverified commit 63f4d8ca authored by Sam Shleifer, committed by GitHub

[Bart/Memory] SelfAttention only returns weights if config.outp… (#3369)

parent 2b2a2f8d
@@ -217,7 +217,9 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
+        x, attn_weights = self.self_attn(
+            query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
+        )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -316,6 +318,7 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
+        self.output_attentions = config.output_attentions
         self.self_attn = SelfAttention(
             embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
         )
@@ -343,14 +346,16 @@ class DecoderLayer(nn.Module):
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
-        x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
+        x, self_attn_weights = self.self_attn(
+            query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
+        )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
-        x, encoder_attn_weights = self.encoder_attn(
+        x, _ = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
             key_padding_mask=encoder_attn_mask,
@@ -527,6 +532,7 @@ class SelfAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
         attn_mask: Optional[Tensor] = None,
+        need_weights=False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """Input shape: Time(SeqLen) x Batch x Channel"""
         static_kv = self.encoder_decoder_attention  # type: bool
@@ -598,7 +604,10 @@ class SelfAttention(nn.Module):
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         attn_output = self.out_proj(attn_output)
-        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        if need_weights:
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights = None
         return attn_output, attn_weights
 
     def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
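The change follows one pattern throughout: the encoder and decoder layers pass need_weights=self.output_attentions into SelfAttention, and SelfAttention only reshapes and returns the per-head weight tensor when asked, returning None otherwise, so a (bsz, num_heads, tgt_len, src_len) tensor is not kept per layer when attentions are not requested. Below is a minimal, self-contained sketch of that pattern; TinySelfAttention, its projections, and the toy sizes are illustrative assumptions, not the transformers implementation.

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class TinySelfAttention(nn.Module):
    """Toy multi-head self-attention that only returns weights on request."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _shape(self, t: Tensor, seq_len: int, bsz: int) -> Tensor:
        # (seq_len, bsz, embed_dim) -> (bsz * num_heads, seq_len, head_dim)
        return t.view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(self, x: Tensor, need_weights: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time(SeqLen) x Batch x Channel, as in the layers above."""
        tgt_len, bsz, embed_dim = x.shape
        q = self._shape(self.q_proj(x) * self.scaling, tgt_len, bsz)
        k = self._shape(self.k_proj(x), tgt_len, bsz)
        v = self._shape(self.v_proj(x), tgt_len, bsz)
        attn_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
        attn_output = torch.bmm(attn_weights, v)  # (bsz * num_heads, tgt_len, head_dim)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        if need_weights:
            # Reshape to (bsz, num_heads, tgt_len, src_len) only when the caller asks for it.
            return attn_output, attn_weights.view(bsz, self.num_heads, tgt_len, tgt_len)
        # Returning None means callers that ignore attentions never hold a reference to the weights.
        return attn_output, None


if __name__ == "__main__":
    attn = TinySelfAttention(embed_dim=16, num_heads=4)
    x = torch.randn(7, 2, 16)  # (seq_len, batch, embed_dim)
    out, weights = attn(x, need_weights=False)  # weights is None, mirroring output_attentions=False
    out, weights = attn(x, need_weights=True)   # weights has shape (2, 4, 7, 7)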