Unverified commit f033498f authored by Charlene Yang, committed by GitHub

[PyTorch] Fix get_swa_mask() for padding masks (#1281)



* WIP: fix get_swa_mask for padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix mask type setting
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix the order of checking valid swa and changing mask type
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revamp to get full mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 83dac8cf
@@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 model_configs_swa = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
-    "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
-    "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
-    "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
-    "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
-    "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
-    "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
-    "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
+    "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
+    "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
+    "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+    "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
+    "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+    "swa_6_0": ModelConfig(
+        4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "swa_6_1": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
 }
...
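For reference, every `swa_*` config above exercises the sliding-window semantics documented in `get_full_mask()` below: a query at position i may attend to keys in [i + seqlen_kv - seqlen_q - window_size[0], i + seqlen_kv - seqlen_q + window_size[1]] inclusive, with -1 meaning "unbounded on that side". A minimal sketch of that band (not part of this commit), using the swa_1_1 sequence shapes and a hypothetical window of (128, 0):

```python
# Minimal sketch (not from this commit): the sliding-window visibility band.
def visible_key_range(i, seqlen_q, seqlen_kv, window_size):
    """Inclusive [lo, hi] key range visible to query position i."""
    left, right = window_size
    offset = seqlen_kv - seqlen_q  # diagonal shift when kv is longer than q
    lo = 0 if left == -1 else max(0, i + offset - left)
    hi = seqlen_kv - 1 if right == -1 else min(seqlen_kv - 1, i + offset + right)
    return lo, hi

# e.g. the swa_1_1 shapes (sq=2048, skv=4096) with a hypothetical window (128, 0)
assert visible_key_range(0, 2048, 4096, (128, 0)) == (1920, 2048)
```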
@@ -1024,27 +1024,51 @@ class InferenceParams: # pylint: disable=too-few-public-methods
 @torch.no_grad()
-def get_swa_mask(
-    window_size: Tuple[int, int],
+def get_full_mask(
     max_seqlen_q: int,
     max_seqlen_kv: int,
     attn_mask_type: str = "no_mask",
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
+    window_size: Tuple[int, int] = None,
+    attention_type: str = "self",
+    bottom_right_alignment: bool = True,
 ) -> torch.Tensor:
     """
-    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
-    For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
-    and for other mask types, the bottom right corner.
+    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
+    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
+    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
+
+        attn_mask_type              output shape                                 diagonal alignment
+        --------------------------------------------------------------------------------------------
+        no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]          follow bottom_right_alignment
+        causal                      [1, 1, max_seqlen_q, max_seqlen_kv]          always top left
+        causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]          always bottom right
+        padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
+        padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
+        padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
+        arbitrary                   same as attention_mask                       follow bottom_right_alignment
+
+    .. note::
+
+        For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right
+        diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
+        i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
+        max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
+        [[False, False, True, True], [False, False, False, False]],
+        [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4]
+        shape and is,::
+
+            [[[False, False, False, True],
+              [False, False, False, True],
+              [ True,  True,  True, True],
+              [ True,  True,  True, True]],
+             [[False, True, True, True],
+              [False, True, True, True],
+              [False, True, True, True],
+              [False, True, True, True]]]
+
     Parameters
     ----------
-    window_size: Tuple[int, int]
-        Sliding window size for local attention, where query at position i attends to keys
-        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
-        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
-        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
-        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
-        `attn_mask_type`.
     max_seqlen_q: int
         Maximum sequence length for queries.
     max_seqlen_kv: int
@@ -1052,33 +1076,105 @@ def get_swa_mask(
     attn_mask_type: str, default = `no_mask`
         Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
         "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         default = `None`
-        Boolean tensor(s) used to mask out attention softmax input.
+        Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
+        for the requirements of `attention_mask` for different `attn_mask_type`s.
+    window_size: Tuple[int, int], default = `None`
+        Sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+        `attn_mask_type`.
+    attention_type: str, default = "self"
+        Attention type, {"self", "cross"}
+    bottom_right_alignment: bool, default = `True`
+        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
+        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
+        specifies "causal" or "causal_bottom_right".

     Returns
     ----------
+    attn_mask_type: str
+        For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
     attention_mask: torch.Tensor
-        Combined `attention_mask` (input) and sliding window attention mask.
-        The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
-        else, the same shape as input `attention_mask`.
+        The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
+    actual_seqlens_q: torch.Tensor
+        For padding masks, the actual sequence lengths for queries, in shape [batch_size].
+        For other masks, `None`.
+    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
+        For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
+        For other masks, `None`.
     """
-    mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda")
-    if attn_mask_type in ["causal"]:
-        left = window_size[0] if window_size[0] != -1 else max_seqlen_q
-        right = window_size[1] if window_size[1] != -1 else max_seqlen_q
-        mask_upper = torch.triu(mask, diagonal=-left)
-        mask_lower = torch.tril(mask_upper, diagonal=right)
-    else:
-        left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
-        right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
-        mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
-        mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
-    attn_mask_type = "arbitrary"
-    mask = mask_lower.logical_not()
+    # perform basic checks
+    change_type = window_size is not None and (
+        window_size[0] != -1 or window_size[1] not in [-1, 0]
+    )
+    if window_size is None:
+        window_size = (-1, -1)
+    if "causal" in attn_mask_type:
+        window_size = (window_size[0], 0)
+    window_size = (
+        max_seqlen_kv if window_size[0] == -1 else window_size[0],
+        max_seqlen_q if window_size[1] == -1 else window_size[1],
+    )
+
+    # apply padding mask
+    actual_seqlens_q = None
+    actual_seqlens_kv = None
+    if "padding" in attn_mask_type:
+        if attention_type == "self":
+            attention_mask = torch.logical_or(
+                attention_mask.squeeze(1).unsqueeze(3), attention_mask
+            )
+        else:
+            attention_mask = torch.logical_or(
+                attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
+            )
+        m = attention_mask.logical_not()
+        actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
+        actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)
+
+    # apply SWA mask
+    mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
+        1, 1, max_seqlen_q, 1
+    ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
+    swa_left = None
+    swa_right = None
+    if attn_mask_type == "causal_bottom_right" or (
+        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
+    ):
+        swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
+        swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
+    elif attn_mask_type in ["causal", "padding_causal"] or (
+        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
+    ):
+        swa_left = mask - window_size[0]
+        swa_right = mask + window_size[1]
+    elif attn_mask_type == "padding_causal_bottom_right" or (
+        attn_mask_type == "padding" and bottom_right_alignment
+    ):
+        batch_size = attention_mask.shape[0]
+        swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
+            actual_seqlens_kv - actual_seqlens_q - window_size[0]
+        ).view(batch_size, 1, 1, 1)
+        swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
+            actual_seqlens_kv - actual_seqlens_q + window_size[1]
+        ).view(batch_size, 1, 1, 1)
+    swa_mask = torch.logical_not(
+        torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
+    )
     if attention_mask is not None:
-        mask = torch.logical_and(attention_mask, mask)
-    return attn_mask_type, mask
+        attention_mask = torch.logical_or(swa_mask, attention_mask)
+    else:
+        attention_mask = swa_mask
+
+    # change mask type
+    if change_type:
+        attn_mask_type = "arbitrary"
+    return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv

 @torch.no_grad()
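To make the docstring's bottom-right note concrete, here is a minimal standalone sketch (not part of this commit) that reproduces the padding/cross example with plain PyTorch on CPU (the real helper allocates on "cuda"). With `window_size=None` the SWA band admits every position, so `get_full_mask()` reduces to exactly this combined padding mask:

```python
import torch

# The docstring example: batch of 2, max_seqlen_q = max_seqlen_kv = 4,
# attn_mask_type = "padding", attention_type = "cross". True = masked out.
q_mask = torch.tensor([[False, False, True, True],
                       [False, False, False, False]]).view(2, 1, 1, 4)
kv_mask = torch.tensor([[False, False, False, True],
                        [False, True, True, True]]).view(2, 1, 1, 4)

# A position (i, j) is masked if either query row i or key column j is padding.
full = torch.logical_or(q_mask.squeeze(1).unsqueeze(3), kv_mask)  # [2, 1, 4, 4]

# Actual sequence lengths, recovered exactly as get_full_mask() does.
m = full.logical_not()
actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)   # tensor([2, 4])
actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)  # tensor([3, 1])

expected = torch.tensor(
    [[[False, False, False, True],
      [False, False, False, True],
      [True,  True,  True,  True],
      [True,  True,  True,  True]],
     [[False, True, True, True],
      [False, True, True, True],
      [False, True, True, True],
      [False, True, True, True]]]
)
assert torch.equal(full.squeeze(1), expected)
```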
@@ -4733,6 +4829,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
         cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
         attn_mask_type: str = "causal",
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+        window_size: Optional[Tuple[int, int]] = None,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
@@ -4752,52 +4849,14 @@ class UnfusedDotProductAttention(torch.nn.Module):
             query_layer.shape[0],
             key_layer.shape[0],
         )
-        if "padding" in attn_mask_type:
-            if self.attention_type == "self":
-                assert attention_mask.shape == (
-                    batch_size,
-                    1,
-                    1,
-                    max_seqlen_q,
-                ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
-                attention_mask = torch.logical_or(
-                    attention_mask.squeeze(1).unsqueeze(3), attention_mask
-                )
-            else:
-                assert (
-                    len(attention_mask) == 2
-                    and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
-                    and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
-                ), (
-                    "attention_mask should be a tuple of two tensors with shapes "
-                    "[b, 1, 1, sq] and [b, 1, 1, skv]!"
-                )
-                attention_mask = torch.logical_or(
-                    attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
-                )
-            mask = attention_mask.squeeze(1).logical_not()
-            actual_seqlens_q = mask[:, :, 0].sum(dim=1)
-            actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
-            mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
-                1, 1, max_seqlen_q, 1
-            ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
-                1, 1, 1, max_seqlen_kv
-            )
-            if attn_mask_type == "padding_causal":
-                attention_mask = torch.logical_or(
-                    torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
-                    attention_mask,
-                )
-            if attn_mask_type == "padding_causal_bottom_right":
-                attention_mask = torch.logical_or(
-                    torch.where(
-                        mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
-                        + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
-                        < 0,
-                        1,
-                        0,
-                    ),
-                    attention_mask,
-                )
+        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
+            max_seqlen_q,
+            max_seqlen_kv,
+            attn_mask_type=attn_mask_type,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            attention_type=self.attention_type,
+        )
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
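Both the deleted inline code and `get_full_mask()` derive their causal masks from the same index-difference grid. A minimal sketch (not from this commit) of that trick, with True meaning "masked out":

```python
import torch

# diff[i, j] = i - j is >= 0 on and below the main diagonal, so thresholding it
# yields causal masks without materializing triu/tril intermediates.
max_seqlen_q, max_seqlen_kv = 3, 5
diff = torch.arange(max_seqlen_q).view(max_seqlen_q, 1) - torch.arange(
    max_seqlen_kv
).view(1, max_seqlen_kv)

# Top-left causal ("padding_causal"): mask keys strictly ahead of the query.
top_left = diff < 0
# Bottom-right causal: shift the diagonal by (seqlen_kv - seqlen_q) first.
bottom_right = (diff + max_seqlen_kv - max_seqlen_q) < 0

assert top_left[0].tolist() == [False, True, True, True, True]
assert bottom_right[0].tolist() == [False, False, False, True, True]
```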
@@ -8274,12 +8333,6 @@ class DotProductAttention(TransformerEngineBaseModule):
             )
             if use_unfused_attention:
-                if window_size is not None and (
-                    window_size[0] != -1 or window_size[1] not in [-1, 0]
-                ):
-                    attn_mask_type, attention_mask = get_swa_mask(
-                        window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
-                    )
                 if checkpoint_core_attention:
                     return self._checkpointed_attention_forward(
                         self.unfused_attention,
@@ -8291,6 +8344,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                         cu_seqlens_kv=cu_seqlens_kv,
                         attn_mask_type=attn_mask_type,
                         attention_mask=attention_mask,
+                        window_size=window_size,
                         core_attention_bias_type=core_attention_bias_type,
                         core_attention_bias=core_attention_bias,
                         alibi_slopes=alibi_slopes,
@@ -8304,6 +8358,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                     cu_seqlens_kv=cu_seqlens_kv,
                     attn_mask_type=attn_mask_type,
                     attention_mask=attention_mask,
+                    window_size=window_size,
                     core_attention_bias_type=core_attention_bias_type,
                     core_attention_bias=core_attention_bias,
                     alibi_slopes=alibi_slopes,
...
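The `window_size` check that used to guard the `get_swa_mask()` call here now lives inside `get_full_mask()` as `change_type`, evaluated before `window_size` is normalized (the "order of checking valid swa and changing mask type" fix from the commit message). A standalone sketch (not from this commit) of that predicate:

```python
# Minimal sketch: the mask type must change to "arbitrary" only for a real
# sliding window; (-1, -1) ("no window") and (-1, 0) ("causal") keep the
# original attn_mask_type.
def needs_arbitrary_mask(window_size):
    return window_size is not None and (
        window_size[0] != -1 or window_size[1] not in [-1, 0]
    )

assert not needs_arbitrary_mask(None)      # no sliding window requested
assert not needs_arbitrary_mask((-1, -1))  # no window
assert not needs_arbitrary_mask((-1, 0))   # plain causal
assert needs_arbitrary_mask((128, 0))      # real sliding window
assert needs_arbitrary_mask((128, 128))
```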