Unverified Commit 467b39a3 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Add support for padding mask in `UnfusedDotProductAttention` (#1073)



* add support for padding in UnfusedDPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support for padding_causal/_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding_causal/_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* need to test max512 backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix mask logic in unfused
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use actual_seqlen for alibi/causal_bottom_right padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes and convert causal to causal_bottom_right for inference
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use causal in kv cache inference test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify get_alibi logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* simplify the non-padding path for get_alibi
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* avoid batch_size loop in generating padding_causal/_bottom_right masks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 26c8fcc9
......@@ -1655,8 +1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal_bottom_right",
enc_dec_attn_mask_type="causal_bottom_right",
self_attn_mask_type="causal",
enc_dec_attn_mask_type="causal",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
......@@ -1670,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal_bottom_right",
attn_mask_type="causal",
params_dtype=dtype,
)
.cuda()
......
......@@ -142,7 +142,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
(bias_type == NVTE_Bias_Type::NVTE_ALIBI &&
attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) ||
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
sm_arch_ >= 90) ||
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
((cudnn_runtime_version >= 90000) &&
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
......
......@@ -472,19 +472,25 @@ def get_attention_backend(
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | supported backends
# -------------------------------------------------------------------
# no_mask | All
# padding | FlashAttention, FusedAttention
# causal |
# self-attention | All
# cross-attention | FusedAttention
# padding_causal |
# self-attention | FlashAttention, FusedAttention
# cross-attention | FusedAttention
# causal_bottom_right | All
# padding_causal_bottom_right | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
# attn_mask_type | attention_mask | supported backends
# ----------------------------------------------------------------------------------------
# no_mask | None | All
# padding | | All
# self-attention | One tensor in shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors in shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# padding_causal | Same as "padding" |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" |
# self-attention | | All
# cross-attention | | FlashAttention, UnfusedDotProductAttention
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if use_flash_attention:
logger.debug("Disabling FlashAttention for arbitrary mask")
......@@ -492,9 +498,6 @@ def get_attention_backend(
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if use_unfused_attention and "padding" in attn_mask_type:
logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
use_unfused_attention = False
if (
use_flash_attention
and _flash_attn_2_1_plus
......@@ -780,7 +783,7 @@ def get_attention_backend(
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference.
to efficiently calculate and store the context during inference.
Parameters
----------
......@@ -886,6 +889,8 @@ def get_alibi(
num_heads: int,
max_seqlen_q: int,
max_seqlen_kv: int,
actual_seqlens_q: Optional[torch.Tensor] = None,
actual_seqlens_kv: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
bias_dtype: Optional[torch.dtype] = None,
bottom_right_alignment: bool = True,
......@@ -899,6 +904,10 @@ def get_alibi(
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None`
Actual sequence lengths for queries, in shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
Actual sequence lengths for keys and values, in shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None`
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
......@@ -912,10 +921,12 @@ def get_alibi(
alibi_slopes: torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
`alibi_slopes` is in [batch_size, num_heads], then the bias is in
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
(2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
[batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
`actual_seqlens_q` and `actual_seqlens_kv` are not `None`.
"""
global _alibi_cache
if _alibi_cache["_alibi_slopes_require_update"]:
......@@ -941,17 +952,23 @@ def get_alibi(
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
if _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
if bottom_right_alignment:
bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
else:
bias = torch.arange(
1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
).view(1, 1, 1, max_seqlen_kv)
bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if actual_seqlens_q is None and actual_seqlens_kv is None:
if bottom_right_alignment:
bias = bias + max_seqlen_kv - max_seqlen_q
elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
batch_size = actual_seqlens_q.shape[0]
bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
if bottom_right_alignment:
bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
else:
assert (
False
), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
bias = bias.abs().mul(-1)
bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
_alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
......@@ -3705,6 +3722,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
def __init__(
self,
softmax_scale: float,
attention_type: str = "self",
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
......@@ -3712,6 +3730,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
super().__init__()
self.softmax_scale = softmax_scale
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
......@@ -3751,6 +3770,58 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
]
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1],
query_layer.shape[0],
key_layer.shape[0],
)
if "padding" in attn_mask_type:
if self.attention_type == "self":
assert attention_mask.shape == (
batch_size,
1,
1,
max_seqlen_q,
), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
assert (
len(attention_mask) == 2
and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
), (
"attention_mask should be a tuple of two tensors with shapes "
"[b, 1, 1, sq] and [b, 1, 1, skv]!"
)
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
mask = attention_mask.squeeze(1).logical_not()
actual_seqlens_q = mask[:, :, 0].sum(dim=1)
actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if attn_mask_type == "padding_causal":
attention_mask = torch.logical_or(
torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
attention_mask,
)
if attn_mask_type == "padding_causal_bottom_right":
attention_mask = torch.logical_or(
torch.where(
mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+ (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
< 0,
1,
0,
),
attention_mask,
)
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
......@@ -3805,7 +3876,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=scale,
)
).view(*output_size)
elif core_attention_bias_type == "pre_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
......@@ -3813,10 +3884,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
)
matmul_result = (
matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias
).view(-1, output_size[2], output_size[3])
matmul_result = matmul_result.view(*output_size) + core_attention_bias
matmul_result *= scale
elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
......@@ -3827,6 +3895,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
output_size[1],
output_size[2],
output_size[3],
actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
......@@ -3837,26 +3907,21 @@ class UnfusedDotProductAttention(torch.nn.Module):
beta=0.0,
alpha=scale,
)
matmul_result = (
(
matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3]
)
+ core_attention_bias
)
.view(-1, output_size[2], output_size[3])
.to(dtype=query_layer.dtype)
matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
dtype=query_layer.dtype
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
attention_scores, attention_mask, attn_mask_type, softmax_scale
matmul_result, attention_mask, attn_mask_type, softmax_scale
)
# mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
# the columns (pad tokens from k) are already zeroed out during softmax
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
......@@ -6232,7 +6297,10 @@ class DotProductAttention(TransformerEngineBaseModule):
)
self.unfused_attention = UnfusedDotProductAttention(
softmax_scale, **attn_kwargs, layer_number=layer_number
softmax_scale,
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
)
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -6522,6 +6590,11 @@ class DotProductAttention(TransformerEngineBaseModule):
if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"
# convert causal to causal_bottom_right in inference when KV-caching is in use
# so users can run with the same attn_mask_type for training and inference
if attn_mask_type in ["causal", "padding_causal"]:
attn_mask_type = attn_mask_type + "_bottom_right"
if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
......@@ -6628,7 +6701,6 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_mask is not None
), "Please provide attention_mask for padding!"
if self.attention_type == "self":
assert max_seqlen_q == max_seqlen_kv
cu_seqlens_q = get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
......
......@@ -329,25 +329,22 @@ class FusedScaleMaskSoftmax(nn.Module):
return False # sk must be 16 ~ 16384
if sk % 8 != 0:
return False # sk must be divisor of 8
if self.attn_mask_type == "arbitrary":
return False # Custom masks not supported
if sq == 1:
return False # sq must be > 1
if self.attn_mask_type == "causal" and sq != sk:
return False # Fused causal kernel only support causal_bottom_right
if (
sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported
):
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "padding":
if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
and mask.shape[0] in [1, b]
and mask.shape[1:] == (1, sq, sk)
):
return True
else:
......@@ -358,13 +355,21 @@ class FusedScaleMaskSoftmax(nn.Module):
def forward_fused_softmax(
self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
) -> torch.Tensor:
"""Fused masked softmax kernel"""
"""
Fused masked softmax path.
attn_mask_type | module
-----------------------------------------------------------------------------------------
no_mask | ScaledSoftmax
causal (self-attention), causal_bottom_right | ScaledAlignedCausalMaskedSoftmax
padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax
arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk]) | ScaledMaskedSoftmax
"""
scale = 1.0 if scale is None else scale
if "causal" in self.attn_mask_type:
if self.attn_mask_type in ["causal", "causal_bottom_right"]:
return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)
# input is 4D tensor (b, np, sq, sk)
# input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk)
if mask is not None and self.attn_mask_type != "no_mask":
return ScaledMaskedSoftmax.apply(inp, mask, scale)
return ScaledSoftmax.apply(inp, scale)
......@@ -379,13 +384,19 @@ class FusedScaleMaskSoftmax(nn.Module):
if scale is not None:
inp = inp * scale
if "causal" in self.attn_mask_type:
if self.attn_mask_type in ["causal", "causal_bottom_right"]:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
assert self.kvcache_max_seq >= seq_len_k
mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
causal_mask = _get_onnx_export_causal_mask(
seq_len_q, seq_len_k, self.onnx_causal_mask
)
else:
causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
if mask is None:
mask = causal_mask
else:
mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
mask = torch.logical_or(mask, causal_mask)
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
......
......@@ -624,7 +624,7 @@ class TransformerLayer(torch.nn.Module):
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference.
to efficiently calculate and store the context during inference.
"""
if self_attn_mask_type is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment