"images/vscode:/vscode.git/clone" did not exist on "d1826c0ed54e23697376c81f73788f88024696c2"
Unverified Commit 63f4d8ca authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Bart/Memory] SelfAttention only returns weights if config.outp… (#3369)

parent 2b2a2f8d
......@@ -217,7 +217,9 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
......@@ -316,6 +318,7 @@ class DecoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.output_attentions = config.output_attentions
self.self_attn = SelfAttention(
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
)
......@@ -343,14 +346,16 @@ class DecoderLayer(nn.Module):
if layer_state is None:
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
x, self_attn_weights = self.self_attn(
query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
x, encoder_attn_weights = self.encoder_attn(
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
......@@ -527,6 +532,7 @@ class SelfAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
need_weights=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv = self.encoder_decoder_attention # type: bool
......@@ -598,7 +604,10 @@ class SelfAttention(nn.Module):
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
if need_weights:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
else:
attn_weights = None
return attn_output, attn_weights
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment