Unverified commit b2505f7d, authored by Sam Shleifer, committed by GitHub

Cleanup bart caching logic (#5640)

parent 838950ee
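
Background for the diff below, not part of the commit itself: BART's `SelfAttention` keeps a per-layer cache dict during incremental decoding, and `cache_key` (visible in the first hunk's context lines) selects the slot: `"self"` for decoder self-attention, whose keys and values grow by one step per generated token, and `"encoder_decoder"` for cross-attention, whose keys and values are static projections of the encoder output. A minimal sketch of that layout; the `(bsz, num_heads, seq_len, head_dim)` storage shape is an assumption, not stated in this diff:

```python
# Illustrative sketch only -- not code from this commit. Names like
# layer_state, prev_key, and prev_key_padding_mask match the diff below.
import torch

bsz, num_heads, head_dim, src_len = 2, 16, 64, 12

layer_state = {
    # Decoder self-attention slot: grows by one time step per generated token.
    "self": {
        "prev_key": torch.zeros(bsz, num_heads, 1, head_dim),
        "prev_value": torch.zeros(bsz, num_heads, 1, head_dim),
        "prev_key_padding_mask": None,
    },
    # Cross-attention slot (static_kv): computed once from the encoder output.
    "encoder_decoder": {
        "prev_key": torch.zeros(bsz, num_heads, src_len, head_dim),
        "prev_value": torch.zeros(bsz, num_heads, src_len, head_dim),
        "prev_key_padding_mask": torch.zeros(bsz, src_len, dtype=torch.bool),
    },
}
```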
```diff
@@ -628,8 +628,8 @@ class SelfAttention(nn.Module):
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
 
-    def _shape(self, tensor, dim_0, bsz):
-        return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+    def _shape(self, tensor, seq_len, bsz):
+        return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
 
     def forward(
         self,
```
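
The first hunk is a pure rename: `_shape`'s second argument is the sequence length of a time-major `(seq_len, bsz, embed_dim)` tensor, so `seq_len` is clearer than `dim_0`. A standalone check of what the reshape does, with made-up sizes:

```python
import torch

seq_len, bsz, num_heads, head_dim = 5, 2, 4, 8
embed_dim = num_heads * head_dim

x = torch.randn(seq_len, bsz, embed_dim)  # time-major, as in this attention module
# Same reshape as SelfAttention._shape: split heads, fold them into the batch dim.
shaped = x.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
assert shaped.shape == (bsz * num_heads, seq_len, head_dim)
```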
```diff
@@ -648,10 +648,9 @@ class SelfAttention(nn.Module):
         # get here for encoder decoder cause of static_kv
         if layer_state is not None:  # reuse k,v and encoder_padding_mask
             saved_state = layer_state.get(self.cache_key, {})
-            if "prev_key" in saved_state:
+            if "prev_key" in saved_state and static_kv:
                 # previous time steps are cached - no need to recompute key and value if they are static
-                if static_kv:
-                    key = None
+                key = None
         else:
             saved_state = None
             layer_state = {}
```
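
The second hunk folds the nested `static_kv` test into the outer condition. The two forms drop `key` in exactly the same case (cached state present and static), since the inner branch previously did nothing when `static_kv` was false. A sketch of the equivalence, using a made-up helper name:

```python
# Hypothetical condensed form of the control flow, for illustration only.
def maybe_drop_key(key, saved_state, static_kv):
    # Old: if "prev_key" in saved_state: if static_kv: key = None
    # New: if "prev_key" in saved_state and static_kv: key = None
    # Both set key to None in exactly the same case.
    if "prev_key" in saved_state and static_kv:
        key = None
    return key

assert maybe_drop_key("k", {"prev_key": 1}, static_kv=True) is None
assert maybe_drop_key("k", {"prev_key": 1}, static_kv=False) == "k"
assert maybe_drop_key("k", {}, static_kv=True) == "k"
```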
```diff
@@ -738,37 +737,14 @@ class SelfAttention(nn.Module):
             v = torch.cat([prev_value, v], dim=1)
         assert k is not None and v is not None
         prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
-        key_padding_mask = self._cat_prev_key_padding_mask(
-            key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
-        )
-        return k, v, key_padding_mask
-
-    @staticmethod
-    def _cat_prev_key_padding_mask(
-        key_padding_mask: Optional[Tensor],
-        prev_key_padding_mask: Optional[Tensor],
-        batch_size: int,
-        src_len: int,
-        static_kv: bool,
-    ) -> Optional[Tensor]:
-        # saved key padding masks have shape (bsz, seq_len)
         if prev_key_padding_mask is not None:
             if static_kv:
                 new_key_padding_mask = prev_key_padding_mask
             else:
                 new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
-        elif key_padding_mask is not None:
-            filler = torch.zeros(
-                batch_size,
-                src_len - key_padding_mask.size(1),
-                dtype=key_padding_mask.dtype,
-                device=key_padding_mask.device,
-            )
-            new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
         else:
-            new_key_padding_mask = prev_key_padding_mask
-        return new_key_padding_mask
+            new_key_padding_mask = key_padding_mask
+        return k, v, new_key_padding_mask
 
 
 class BartClassificationHead(nn.Module):
```
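
The third hunk inlines `_cat_prev_key_padding_mask` into its only caller and deletes the `filler` branch, which zero-padded an incoming mask that was shorter than the cached keys; after the cleanup, when there is no cached mask the current mask is returned unchanged. A standalone sketch of the surviving merge logic (`merge_padding_mask` is a hypothetical name, not from the commit):

```python
from typing import Optional

import torch
from torch import Tensor

def merge_padding_mask(
    key_padding_mask: Optional[Tensor],
    prev_key_padding_mask: Optional[Tensor],
    static_kv: bool,
) -> Optional[Tensor]:
    # Mirrors the inlined logic: keep the cached mask for static (cross-attn)
    # caches, concatenate along time for growing self-attn caches, and
    # otherwise fall through to whatever mask the current step provided.
    if prev_key_padding_mask is not None:
        if static_kv:
            return prev_key_padding_mask
        return torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
    return key_padding_mask

prev = torch.zeros(2, 3, dtype=torch.bool)
cur = torch.ones(2, 1, dtype=torch.bool)
assert merge_padding_mask(cur, prev, static_kv=True).shape == (2, 3)
assert merge_padding_mask(cur, prev, static_kv=False).shape == (2, 4)
assert merge_padding_mask(cur, None, static_kv=False).shape == (2, 1)
```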