Unverified Commit 4c7f564f authored by ZhuBaohe's avatar ZhuBaohe Committed by GitHub
Browse files

fix (#4839)

parent 37be3786
......@@ -153,12 +153,11 @@ class LongformerSelfAttention(nn.Module):
beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
seqlen = input_tensor.size(1)
beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
beginning_mask = beginning_mask[:, :seqlen].expand(beginning_input.size())
beginning_mask = beginning_mask.expand(beginning_input.size())
beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
ending_mask = ending_mask[:, -seqlen:].expand(ending_input.size())
ending_mask = ending_mask.expand(ending_input.size())
ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
......@@ -301,7 +300,6 @@ class LongformerSelfAttention(nn.Module):
k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_weights = (batch_size, seqlen, num_heads, window*2+1)
attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
self._mask_invalid_locations(attn_weights, self.one_sided_attention_window_size)
if remove_from_windowed_attention_mask is not None:
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
# from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
......@@ -329,7 +327,7 @@ class LongformerSelfAttention(nn.Module):
selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
# (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0
# concat to attn_weights
# (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment