Unverified commit 47ca514a, authored by Kirthi Shankar Sivamani and committed by GitHub

Support packed input for FA (#302)



* initial changes [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add padding mask support for FA
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm causal mask from tests and add padding
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix some conflicts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* conflicts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add unpadding mask
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix padding mask
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] fix API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add packing and unpacking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix atomic_add bf16 torch.compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Generate non all True masks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core attention export and FusedAttn filter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix all ONNX tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Memory optimization
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Optimizations and caching fixes in torch.dynamo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Padding optimizations
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes and reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d3157e2a
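The overall change teaches the FlashAttention path to consume packed variable-length input: padded batches are packed into a contiguous token layout with cumulative sequence lengths before the kernel call and unpacked afterwards. A minimal sketch of that packing/unpacking idea in PyTorch follows; the helper names and the mask polarity (True = keep) are illustrative assumptions, not the API added by this commit.

import torch

def pack_for_flash_attention(x, valid_mask):
    # x: padded activations of shape [batch, seq, heads, dim].
    # valid_mask: bool [batch, seq], True where a token is real (illustrative polarity).
    b = x.shape[0]
    seqlens = valid_mask.sum(dim=1, dtype=torch.int32)           # tokens per sequence
    cu_seqlens = torch.zeros(b + 1, dtype=torch.int32, device=x.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    packed = x[valid_mask]                                        # [total_tokens, heads, dim]
    return packed, cu_seqlens, int(seqlens.max())

def unpack_from_flash_attention(packed, valid_mask, batch, seq):
    # Scatter packed [total_tokens, heads, dim] output back into a padded tensor.
    out = packed.new_zeros((batch, seq) + tuple(packed.shape[1:]))
    out[valid_mask] = packed
    return out

A production version would additionally need to cover the backward pass and the exact query/key/value layouts FlashAttention expects.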
@@ -612,14 +612,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
block = _test_e2e_checkpointing_get_model(config, dtype)
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
None,
)
loss = te_out.sum()
loss.backward()
@@ -650,7 +649,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
None,
)
loss = te_out.sum()
loss.backward()
@@ -316,9 +316,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
# See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
if attn_mask_type is None:
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_padding-no-mask"
attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
return attn_mask_str
@@ -986,14 +986,14 @@ def test_export_layernorm_mlp(
@skip_FP8
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [
(torch.float32, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
])
def test_export_core_attention(
seed_default_rng,
@@ -1014,7 +1014,7 @@ def test_export_core_attention(
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (query_layer, key_layer, value_layer, attention_mask)
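The exported core-attention test now draws its random mask in the [batch, 1, 1, seq_len] padding-mask layout rather than a per-head 4D shape. A small illustrative sketch (sizes made up) of why that layout is convenient: it broadcasts over heads and query positions when applied to the attention scores.

import torch

b, heads, sq, sk = 2, 8, 16, 16                        # illustrative sizes
mask = torch.bernoulli(0.5 * torch.ones(b, 1, 1, sk)).bool()
scores = torch.randn(b, heads, sq, sk)
# The [b, 1, 1, sk] mask broadcasts against [b, heads, sq, sk] scores.
masked = scores.masked_fill(mask, float("-inf"))       # assuming True means "mask out"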
@@ -1043,9 +1043,8 @@ def test_export_core_attention(
test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
@@ -157,18 +157,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
if skip_wgrad:
_disable_wgrads(block)
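The rewritten mask construction also fixes a subtle test bug (see the "Generate non all True masks" commit above): torch.rand draws from [0, 1), so casting it to bool yields True almost everywhere, while torch.randint(2, ...) gives a genuine 0/1 mask. A quick sketch of the difference:

import torch

shape = (1, 1, 8, 8)
almost_all_true = torch.rand(shape).bool()       # nonzero values cast to True ~always
balanced_mask = torch.randint(2, shape).bool()   # real 50/50 boolean mask

print(almost_all_true.float().mean().item())     # ~1.0
print(balanced_mask.float().mean().item())       # ~0.5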
@@ -193,18 +182,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
if skip_wgrad:
_disable_wgrads(block)
@@ -233,18 +211,24 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
if skip_wgrad:
_disable_wgrads(block)
@@ -261,18 +245,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
if skip_wgrad:
_disable_wgrads(block)
@@ -282,7 +256,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
encoder_output=te_inp_hidden_states
encoder_output=te_inp_hidden_states,
enc_dec_attn_mask=enc_dec_attn_mask,
)
loss = te_out.sum()
loss.backward()
@@ -541,13 +516,14 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding",
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@@ -22,7 +22,7 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "no_mask")
AttnMaskTypes = ("causal", "padding", "arbitrary", "no_mask")
AttnTypes = ("self", "cross")
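With "arbitrary" added to AttnMaskTypes, callers can validate a requested mask type before constructing a module. A trivial illustrative guard (the constant is copied from the diff; the helper itself is hypothetical):

AttnMaskTypes = ("causal", "padding", "arbitrary", "no_mask")

def check_attn_mask_type(attn_mask_type: str) -> None:
    # Hypothetical guard mirroring the kind of assertion the modules perform.
    assert attn_mask_type in AttnMaskTypes, (
        f"Unsupported attn_mask_type {attn_mask_type!r}; expected one of {AttnMaskTypes}"
    )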
@@ -261,21 +261,22 @@ class FusedScaleMaskSoftmax(nn.Module):
scale is None or self.softmax_in_fp32
), "softmax should be in fp32 when scaled"
if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode():
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user want to fuse
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048
and sk % 8 == 0 # sk must be divisor of 8
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(int(sk))
@@ -283,6 +284,14 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
elif self.attn_mask_type == "padding":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
):
return True
else:
if sq % batch_per_block == 0:
return True
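The new padding branch only takes the fused kernel when a mask is actually provided and its last two dimensions equal the query and key sequence lengths (on top of the existing divisibility and size limits). An illustrative mask that satisfies that shape check, built by expanding a per-key padding mask; all sizes here are made up:

import torch

b, sq, sk = 2, 512, 512                                  # illustrative sizes
key_padding = torch.randint(2, (b, 1, 1, sk)).bool()     # [b, 1, 1, sk] padding mask
mask = key_padding.expand(b, 1, sq, sk)                  # last two dims now (sq, sk)
assert mask.shape[-2] == sq and mask.shape[-1] == sk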
@@ -303,7 +312,7 @@ class FusedScaleMaskSoftmax(nn.Module):
probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
return probs.view(b, np, sq, sk)
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
if mask is not None and self.attn_mask_type != "no_mask":
return ScaledMaskedSoftmax.apply(inp, mask, scale)
return ScaledSoftmax.apply(inp, scale)
@@ -325,7 +334,9 @@ class FusedScaleMaskSoftmax(nn.Module):
else:
mask = _get_default_causal_mask(inp.size(2))
mask_output = self.mask_func(inp, mask) if mask is not None else inp
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
mask_output = self.mask_func(inp, mask)
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
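The torch-softmax fallback now ignores a supplied mask entirely when the mask type is "no_mask". A condensed sketch of that control flow; mask_func here is a stand-in fill-with-large-negative implementation, an assumption rather than the library's exact function:

import torch

def mask_func(scores, mask):
    # Assumed convention: True in `mask` marks positions to suppress.
    return scores.masked_fill(mask, -10000.0)

def softmax_with_optional_mask(inp, mask, attn_mask_type):
    mask_output = inp
    if mask is not None and attn_mask_type != "no_mask":
        mask_output = mask_func(inp, mask)
    return torch.nn.Softmax(dim=-1)(mask_output)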
@@ -6,7 +6,7 @@
import os
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
@@ -127,7 +127,7 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`self_attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
@@ -429,7 +429,7 @@ class TransformerLayer(torch.nn.Module):
def set_context_parallel_running(
self,
cp_group: Union[dist_group_type, None],
cp_global_ranks: Union[int],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group and CP dual-stream running"""
@@ -460,16 +460,17 @@ class TransformerLayer(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
is set to `"causal"`.
Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
includes `"padding"` or `"arbitrary"`.
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Boolean tensor used to mask out self-attention softmax input.
Can be a tuple of 2 masks for cross attention with padding masks.
self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
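Putting the documented convention together: attention_mask only matters for the "padding" and "arbitrary" mask types, and cross-attention setups can pass a tuple of two masks. A hedged usage sketch of the forward call under a padding mask (constructor arguments abbreviated; shapes follow the [seq, batch, hidden] default, and the all-False mask below simply masks nothing):

import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="padding",   # attention_mask is consulted for padding/arbitrary
).cuda()

seq_len, batch = 128, 4
hidden = torch.randn(seq_len, batch, 1024, device="cuda")
padding_mask = torch.zeros(batch, 1, 1, seq_len, dtype=torch.bool, device="cuda")
out = layer(hidden, attention_mask=padding_mask)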