Unverified Commit 467b39a3 authored by Charlene Yang, committed by GitHub

[PyTorch] Add support for padding mask in `UnfusedDotProductAttention` (#1073)



* add support for padding in UnfusedDPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support for padding_causal/_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding_causal/_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* need to test max512 backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix mask logic in unfused
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use actual_seqlen for alibi/causal_bottom_right padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes and convert causal to causal_bottom_right for inference
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use causal in kv cache inference test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify get_alibi logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* simplify the non-padding path for get_alibi
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* avoid batch_size loop in generating padding_causal/_bottom_right masks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 26c8fcc9
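
For orientation before the diff: the padding mask convention used throughout this PR marks padded positions with `True` in a boolean tensor of shape [b, 1, 1, sq] for self-attention (and a pair of [b, 1, 1, sq] / [b, 1, 1, skv] tensors for cross-attention). A minimal, self-contained sketch of building such a mask from per-sequence lengths (the variable names and sizes here are illustrative, not taken from the diff):

```python
import torch

# Hypothetical batch of 3 sequences padded to max_seqlen_q = 6.
actual_seqlens_q = torch.tensor([6, 4, 2])
max_seqlen_q = 6

# True marks a padded (masked-out) position, matching the [b, 1, 1, sq] convention.
padding_mask = (
    torch.arange(max_seqlen_q).view(1, 1, 1, max_seqlen_q)
    >= actual_seqlens_q.view(-1, 1, 1, 1)
)
print(padding_mask.shape)  # torch.Size([3, 1, 1, 6])
```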
@@ -1655,8 +1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
             ffn_hidden_size=4 * D,
             num_attention_heads=H,
             attn_input_format=input_format,
-            self_attn_mask_type="causal_bottom_right",
-            enc_dec_attn_mask_type="causal_bottom_right",
+            self_attn_mask_type="causal",
+            enc_dec_attn_mask_type="causal",
             layer_number=layer_number,
             attention_dropout=0.0,
             params_dtype=dtype,
@@ -1670,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
             qkv_format=input_format,
             layer_number=layer_number,
             attention_dropout=0.0,
-            attn_mask_type="causal_bottom_right",
+            attn_mask_type="causal",
             params_dtype=dtype,
         )
         .cuda()
......
@@ -142,7 +142,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
           (bias_type == NVTE_Bias_Type::NVTE_ALIBI &&
            attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK &&
-           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) ||
+           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
+           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
+           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
+           sm_arch_ >= 90) ||
           (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
         ((cudnn_runtime_version >= 90000) &&
          (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
......
@@ -472,19 +472,25 @@ def get_attention_backend(
             use_fused_attention = False

    # Filter: Attention mask
-    # attn_mask_type                              | supported backends
-    # -------------------------------------------------------------------
-    # no_mask                                     | All
-    # padding                                     | FlashAttention, FusedAttention
-    # causal                                      |
-    #     self-attention                          | All
-    #     cross-attention                         | FusedAttention
-    # padding_causal                              |
-    #     self-attention                          | FlashAttention, FusedAttention
-    #     cross-attention                         | FusedAttention
-    # causal_bottom_right                         | All
-    # padding_causal_bottom_right                 | FlashAttention, FusedAttention
-    # arbitrary                                   | UnfusedDotProductAttention
+    # attn_mask_type                              | attention_mask                          | supported backends
+    # ----------------------------------------------------------------------------------------
+    # no_mask                                     | None                                    | All
+    # padding                                     |                                         | All
+    #     self-attention                          | One tensor in shape [b, 1, 1, sq]       |
+    #     cross-attention                         | Tuple of two tensors in shapes          |
+    #                                             | [b, 1, 1, sq] and [b, 1, 1, skv]        |
+    # causal                                      | None                                    |
+    #     self-attention                          |                                         | All
+    #     cross-attention                         |                                         | FusedAttention, UnfusedDotProductAttention
+    # padding_causal                              | Same as "padding"                       |
+    #     self-attention                          |                                         | All
+    #     cross-attention                         |                                         | FusedAttention, UnfusedDotProductAttention
+    # causal_bottom_right                         | None                                    | All
+    # padding_causal_bottom_right                 | Same as "padding"                       |
+    #     self-attention                          |                                         | All
+    #     cross-attention                         |                                         | FlashAttention, UnfusedDotProductAttention
+    # arbitrary                                   | One tensor in shape broadcastable to    | UnfusedDotProductAttention
+    #                                             | [b, h, sq, skv]                         |
    if attn_mask_type == "arbitrary":
        if use_flash_attention:
            logger.debug("Disabling FlashAttention for arbitrary mask")
@@ -492,9 +498,6 @@ def get_attention_backend(
        if use_fused_attention:
            logger.debug("Disabling FusedAttention for arbitrary mask")
            use_fused_attention = False
-    if use_unfused_attention and "padding" in attn_mask_type:
-        logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
-        use_unfused_attention = False
    if (
        use_flash_attention
        and _flash_attn_2_1_plus
@@ -780,7 +783,7 @@ def get_attention_backend(
 class InferenceParams:  # pylint: disable=too-few-public-methods
     """
     Inference parameters that are passed to the main model in order
-    to efficienly calculate and store the context during inference.
+    to efficiently calculate and store the context during inference.

     Parameters
     ----------
@@ -886,6 +889,8 @@ def get_alibi(
     num_heads: int,
     max_seqlen_q: int,
     max_seqlen_kv: int,
+    actual_seqlens_q: Optional[torch.Tensor] = None,
+    actual_seqlens_kv: Optional[torch.Tensor] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     bias_dtype: Optional[torch.dtype] = None,
     bottom_right_alignment: bool = True,
@@ -899,6 +904,10 @@ def get_alibi(
         Maximum sequence length for queries.
     max_seqlen_kv: int
         Maximum sequence length for keys and values.
+    actual_seqlens_q: Optional[torch.Tensor], default = `None`
+        Actual sequence lengths for queries, in shape [batch_size].
+    actual_seqlens_kv: Optional[torch.Tensor], default = `None`
+        Actual sequence lengths for keys and values, in shape [batch_size].
     alibi_slopes: Optional[torch.Tensor], default = `None`
         Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
     bias_dtype: Optional[torch.dtype], default = `None`
@@ -912,10 +921,12 @@ def get_alibi(
     alibi_slopes: torch.Tensor
         ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
     alibi_bias: torch.Tensor
-        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
-        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
-        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
-        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
+        ALiBi bias in FP32 or `bias_dtype`. Its shape is
+        (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
+        and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
+        (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
+        [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
+        `actual_seqlens_q` and `actual_seqlens_kv` are not `None`.
     """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
@@ -941,17 +952,23 @@ def get_alibi(
        slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
        if _alibi_cache["_alibi_slopes"].dim() == 2:
            slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
-        if bottom_right_alignment:
-            bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
-                1, 1, 1, max_seqlen_kv
-            )
-        else:
-            bias = torch.arange(
-                1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
-            ).view(1, 1, 1, max_seqlen_kv)
-        bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
-            1, 1, max_seqlen_q, 1
-        )
+        bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
+            1, 1, max_seqlen_q, 1
+        ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
+            1, 1, 1, max_seqlen_kv
+        )
+        if actual_seqlens_q is None and actual_seqlens_kv is None:
+            if bottom_right_alignment:
+                bias = bias + max_seqlen_kv - max_seqlen_q
+        elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
+            batch_size = actual_seqlens_q.shape[0]
+            bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+            if bottom_right_alignment:
+                bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
+        else:
+            assert (
+                False
+            ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
        bias = bias.abs().mul(-1)
        bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
@@ -3705,6 +3722,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
     def __init__(
         self,
         softmax_scale: float,
+        attention_type: str = "self",
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
         layer_number: Optional[int] = None,
@@ -3712,6 +3730,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
         super().__init__()

         self.softmax_scale = softmax_scale
+        self.attention_type = attention_type
         self.attention_dropout_ctx = attention_dropout_ctx
         self.layer_number = layer_number
@@ -3751,6 +3770,58 @@ class UnfusedDotProductAttention(torch.nn.Module):
             query_layer, key_layer, value_layer = [
                 x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
             ]

+        batch_size, max_seqlen_q, max_seqlen_kv = (
+            query_layer.shape[1],
+            query_layer.shape[0],
+            key_layer.shape[0],
+        )
+        if "padding" in attn_mask_type:
+            if self.attention_type == "self":
+                assert attention_mask.shape == (
+                    batch_size,
+                    1,
+                    1,
+                    max_seqlen_q,
+                ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
+                attention_mask = torch.logical_or(
+                    attention_mask.squeeze(1).unsqueeze(3), attention_mask
+                )
+            else:
+                assert (
+                    len(attention_mask) == 2
+                    and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
+                    and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
+                ), (
+                    "attention_mask should be a tuple of two tensors with shapes "
+                    "[b, 1, 1, sq] and [b, 1, 1, skv]!"
+                )
+                attention_mask = torch.logical_or(
+                    attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
+                )
+            mask = attention_mask.squeeze(1).logical_not()
+            actual_seqlens_q = mask[:, :, 0].sum(dim=1)
+            actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
+            mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
+                1, 1, max_seqlen_q, 1
+            ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
+                1, 1, 1, max_seqlen_kv
+            )
+            if attn_mask_type == "padding_causal":
+                attention_mask = torch.logical_or(
+                    torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
+                    attention_mask,
+                )
+            if attn_mask_type == "padding_causal_bottom_right":
+                attention_mask = torch.logical_or(
+                    torch.where(
+                        mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+                        + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
+                        < 0,
+                        1,
+                        0,
+                    ),
+                    attention_mask,
+                )
+
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
         apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
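
The block above is easier to follow with a concrete toy case. The sketch below (boolean CPU tensors, hypothetical sizes, not copied from the diff) mirrors how the forward pass recovers per-batch valid lengths from the expanded padding mask and then folds in a top-left-aligned causal constraint for `padding_causal`:

```python
import torch

b, sq, skv = 1, 4, 4
pad = torch.zeros(b, 1, 1, skv, dtype=torch.bool)
pad[..., 3] = True  # last token is padding
attention_mask = torch.logical_or(pad.squeeze(1).unsqueeze(3), pad)  # [b, 1, sq, skv]

# Valid (unmasked) lengths per batch entry, read off the first column/row.
valid = attention_mask.squeeze(1).logical_not()
actual_seqlens_q = valid[:, :, 0].sum(dim=1)
actual_seqlens_kv = valid[:, 0, :].sum(dim=1)

# Top-left aligned causal constraint: mask keys that come after the query index.
rel = torch.arange(sq).view(1, 1, sq, 1) - torch.arange(skv).view(1, 1, 1, skv)
padding_causal = torch.logical_or(rel < 0, attention_mask)
print(actual_seqlens_q, actual_seqlens_kv)  # tensor([3]) tensor([3])
```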
@@ -3805,7 +3876,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
                 key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                 beta=0.0,
                 alpha=scale,
-            )
+            ).view(*output_size)

         elif core_attention_bias_type == "pre_scale_bias":
             assert core_attention_bias is not None, "core_attention_bias should not be None!"
@@ -3813,10 +3884,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
                 query_layer.transpose(0, 1),  # [b * np, sq, hn]
                 key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             )
-            matmul_result = (
-                matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
-                + core_attention_bias
-            ).view(-1, output_size[2], output_size[3])
+            matmul_result = matmul_result.view(*output_size) + core_attention_bias
             matmul_result *= scale

         elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
@@ -3827,6 +3895,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
                     output_size[1],
                     output_size[2],
                     output_size[3],
+                    actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
+                    actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                     alibi_slopes=alibi_slopes,
                     bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                 )
@@ -3837,26 +3907,21 @@ class UnfusedDotProductAttention(torch.nn.Module):
                 beta=0.0,
                 alpha=scale,
             )
-            matmul_result = (
-                (
-                    matmul_result.view(
-                        output_size[0], output_size[1], output_size[2], output_size[3]
-                    )
-                    + core_attention_bias
-                )
-                .view(-1, output_size[2], output_size[3])
-                .to(dtype=query_layer.dtype)
+            matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
+                dtype=query_layer.dtype
             )

-        # change view to [b, np, sq, sk]
-        attention_scores = matmul_result.view(*output_size)
-
         # attention scores and attention mask [b, np, sq, sk]
         softmax_scale = self.layer_number if apply_qk_layer_scaling else None
         attention_probs = self.scale_mask_softmax(
-            attention_scores, attention_mask, attn_mask_type, softmax_scale
+            matmul_result, attention_mask, attn_mask_type, softmax_scale
         )
+        # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
+        # the columns (pad tokens from k) are already zeroed out during softmax
+        if "padding" in attn_mask_type:
+            attention_probs = attention_probs.masked_fill(attention_mask, 0)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         with self.attention_dropout_ctx():
@@ -6232,7 +6297,10 @@ class DotProductAttention(TransformerEngineBaseModule):
             )
         self.unfused_attention = UnfusedDotProductAttention(
-            softmax_scale, **attn_kwargs, layer_number=layer_number
+            softmax_scale,
+            attention_type=attention_type,
+            **attn_kwargs,
+            layer_number=layer_number,
         )

     def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
@@ -6522,6 +6590,11 @@ class DotProductAttention(TransformerEngineBaseModule):
        if inference_params is not None:
            assert self.layer_number is not None, "Layer number must be set!"

+            # convert causal to causal_bottom_right in inference when KV-caching is in use
+            # so users can run with the same attn_mask_type for training and inference
+            if attn_mask_type in ["causal", "padding_causal"]:
+                attn_mask_type = attn_mask_type + "_bottom_right"
+
            if qkv_format == "bshd":
                key_layer = key_layer.transpose(0, 1)
                value_layer = value_layer.transpose(0, 1)
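
The reason for this conversion shows up in incremental decoding with a KV cache: the new query covers only the latest token, so a top-left-aligned "causal" mask would hide every cached key, while the bottom-right-aligned variant keeps them visible. A tiny illustrative sketch (not part of the diff):

```python
import torch

sq, skv = 1, 4  # one new query token attending over 4 cached key tokens
rel = torch.arange(sq).view(sq, 1) - torch.arange(skv).view(1, skv)

top_left_masked = rel < 0                   # "causal": hides keys 1..3
bottom_right_masked = rel + (skv - sq) < 0  # "causal_bottom_right": hides nothing
print(top_left_masked.int())      # tensor([[0, 1, 1, 1]])
print(bottom_right_masked.int())  # tensor([[0, 0, 0, 0]])
```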
@@ -6628,7 +6701,6 @@ class DotProductAttention(TransformerEngineBaseModule):
                    attention_mask is not None
                ), "Please provide attention_mask for padding!"
                if self.attention_type == "self":
-                    assert max_seqlen_q == max_seqlen_kv
                    cu_seqlens_q = get_cu_seqlens(attention_mask)
                    cu_seqlens_kv = cu_seqlens_q
                else:
......
@@ -329,25 +329,22 @@ class FusedScaleMaskSoftmax(nn.Module):
            return False  # sk must be 16 ~ 16384
        if sk % 8 != 0:
            return False  # sk must be divisor of 8
-        if self.attn_mask_type == "arbitrary":
-            return False  # Custom masks not supported
+        if sq == 1:
+            return False  # sq must be > 1
        if self.attn_mask_type == "causal" and sq != sk:
            return False  # Fused causal kernel only support causal_bottom_right
        if (
            sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
-            and self.attn_mask_type != "arbitrary"  # Custom masks not supported
        ):
            batch_per_block = self.get_batch_per_block(int(sk))

-            if self.attn_mask_type == "padding":
+            if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary":
                if (
                    mask is not None
                    and sq % batch_per_block == 0
-                    and mask.shape[-2] == sq
-                    and mask.shape[-1] == sk
+                    and mask.shape[0] in [1, b]
+                    and mask.shape[1:] == (1, sq, sk)
                ):
                    return True
                else:
@@ -358,13 +355,21 @@ class FusedScaleMaskSoftmax(nn.Module):
    def forward_fused_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
-        """Fused masked softmax kernel"""
+        """
+        Fused masked softmax path.
+
+        attn_mask_type                                        | module
+        -----------------------------------------------------------------------------------------
+        no_mask                                               | ScaledSoftmax
+        causal (self-attention), causal_bottom_right          | ScaledAlignedCausalMaskedSoftmax
+        padding, padding_causal, padding_causal_bottom_right  | ScaledMaskedSoftmax
+        arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk])          | ScaledMaskedSoftmax
+        """
        scale = 1.0 if scale is None else scale

-        if "causal" in self.attn_mask_type:
+        if self.attn_mask_type in ["causal", "causal_bottom_right"]:
            return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)

-        # input is 4D tensor (b, np, sq, sk)
+        # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk)
        if mask is not None and self.attn_mask_type != "no_mask":
            return ScaledMaskedSoftmax.apply(inp, mask, scale)
        return ScaledSoftmax.apply(inp, scale)
@@ -379,13 +384,19 @@ class FusedScaleMaskSoftmax(nn.Module):
        if scale is not None:
            inp = inp * scale

-        if "causal" in self.attn_mask_type:
+        if self.attn_mask_type in ["causal", "causal_bottom_right"]:
            seq_len_q, seq_len_k = inp.size(2), inp.size(3)
            if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
                assert self.kvcache_max_seq >= seq_len_k
-                mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
+                causal_mask = _get_onnx_export_causal_mask(
+                    seq_len_q, seq_len_k, self.onnx_causal_mask
+                )
            else:
-                mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
+                causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
+            if mask is None:
+                mask = causal_mask
+            else:
+                mask = torch.logical_or(mask, causal_mask)

        mask_output = inp
        if mask is not None and self.attn_mask_type != "no_mask":
......
@@ -624,7 +624,7 @@ class TransformerLayer(torch.nn.Module):
            Whether to set output tensors to 0 or not before use.
        inference_params: InferenceParams, default = None
            Inference parameters that are passed to the main model in order
-            to efficienly calculate and store the context during inference.
+            to efficiently calculate and store the context during inference.
        """
        if self_attn_mask_type is None:
......