Unverified commit 47ca514a authored by Kirthi Shankar Sivamani, committed by GitHub

Support packed input for FA (#302)



* initial changes [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add padding mask support for FA
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm causal mask from tests and add padding
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix some conflicts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* conflicts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add unpadding mask
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix padding mask
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] fix API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add packing and unpacking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix atomic_add bf16 torch.compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Generate non all True masks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core attention export and FusedAttn filter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix all ONNX tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Memory optimization
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Optimizations and caching fixes in torch.dynamo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Padding optimizations
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes and reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d3157e2a
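This PR adds support for packed (variable-length) input on the FlashAttention path: padding tokens are dropped, the remaining tokens from all sequences are concatenated into one tensor, and cumulative sequence lengths record where each sequence starts. The sketch below only illustrates that general packing/unpacking idea in plain PyTorch; the helper names and the mask convention (True = valid token) are assumptions for illustration, not the API added by this PR.

import torch

def pack_sequences(x, attention_mask):
    """x: [b, s, h]; attention_mask: [b, s] bool, True = valid token (assumed convention)."""
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)              # valid length per sequence
    cu_seqlens = torch.zeros(x.size(0) + 1, dtype=torch.int32, device=x.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                        # cumulative offsets for FA
    indices = attention_mask.flatten().nonzero(as_tuple=False).squeeze(1)
    packed = x.reshape(-1, x.size(-1)).index_select(0, indices)          # [total_tokens, h]
    return packed, indices, cu_seqlens, int(seqlens.max())

def unpack_sequences(packed, indices, batch, seqlen, hidden):
    """Scatter packed rows back into a zero-padded [b, s, h] tensor."""
    out = torch.zeros(batch * seqlen, hidden, dtype=packed.dtype, device=packed.device)
    out.index_copy_(0, indices, packed)
    return out.view(batch, seqlen, hidden)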
@@ -612,14 +612,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
     te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

     block = _test_e2e_checkpointing_get_model(config, dtype)

     for _ in range(steps // 2):
         te_out = block(
             te_inp_hidden_states,
-            te_inp_attn_mask,
+            None,
         )
         loss = te_out.sum()
         loss.backward()
@@ -650,7 +649,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
     for _ in range(steps // 2):
         te_out = block(
             te_inp_hidden_states,
-            te_inp_attn_mask,
+            None,
         )
         loss = te_out.sum()
         loss.backward()
......
@@ -316,9 +316,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
     # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
     if attn_mask_type is None:
         return "_mask" if use_mask else "_no-mask"
-    attn_mask_str = "_padding-no-mask"
+    attn_mask_str = "_arbitrary-no-mask"
     attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
-    attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
+    attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
     return attn_mask_str
@@ -986,13 +986,13 @@ def test_export_layernorm_mlp(
 @skip_FP8
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type", [
-    (torch.float32, True, "padding"),     # calls forward_torch_softmax (apply user mask)
+    (torch.float32, True, "arbitrary"),   # calls forward_torch_softmax (apply user mask)
     (torch.float32, False, "no_mask"),    # calls forward_torch_softmax (apply no mask)
     (torch.float16, False, "causal"),     # calls forward_torch_softmax (apply dynamic onnx mask)
-    (torch.float16, True, "padding"),     # calls forward_torch_softmax (apply user mask)
+    (torch.float16, True, "arbitrary"),   # calls forward_torch_softmax (apply user mask)
     (torch.float16, False, "no_mask"),    # calls forward_torch_softmax (apply no mask)
     (torch.bfloat16, False, "causal"),    # calls forward_torch_softmax (apply dynamic onnx mask)
-    (torch.bfloat16, True, "padding"),    # calls forward_torch_softmax (apply user mask)
+    (torch.bfloat16, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
     (torch.bfloat16, False, "no_mask"),   # calls forward_torch_softmax (apply no mask)
 ])
 def test_export_core_attention(
@@ -1014,7 +1014,7 @@ def test_export_core_attention(
     attention_mask = None
     if use_mask:
         # Generate a random mask with 50% probability for 0 or 1.
-        probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
+        probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
         attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
     inp = (query_layer, key_layer, value_layer, attention_mask)
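The export test now builds its random mask in the padding-mask layout, `(batch_size, 1, 1, seq_len)`, which broadcasts over heads and query positions. As a rough illustration of how such a mask is usually derived from per-sequence valid lengths (assuming the convention that True marks positions to be masked out):

import torch

seq_len = 8
lengths = torch.tensor([8, 5, 3])                        # valid tokens per sequence in the batch
positions = torch.arange(seq_len)                        # [seq_len]
padding_mask = positions[None, :] >= lengths[:, None]    # [batch, seq_len], True where padded
padding_mask = padding_mask.view(-1, 1, 1, seq_len)      # [batch, 1, 1, seq_len] for broadcasting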
@@ -1043,9 +1043,8 @@ def test_export_core_attention(
 test_configs_multihead_attention = [
     #"use_mask, attn_mask_type"
-    (False, "causal"),     # calls ScaledUpperTriangMaskedSoftmax
-    (True, "padding"),     # calls ScaledMaskedSoftmax
-    (False, "padding"),    # calls ScaledSoftmax
+    (False, "no_mask"),    # calls ScaledSoftmax
+    (True, "arbitrary"),   # calls ScaledMaskedSoftmax
 ]
 test_configs_attention_type = [
     #"input_layernorm, attention_type, fuse_qkv_params"
......
@@ -157,18 +157,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
     te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

     if skip_wgrad:
         _disable_wgrads(block)
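The switch from `torch.rand(...).bool()` to `torch.randint(2, ...)` matches the "Generate non all True masks" commit: casting uniform floats to bool is True for any nonzero value, so the old masks were effectively constant. A quick comparison:

import torch

almost_all_true = torch.rand((1, 1, 4, 4)).bool()        # nonzero floats cast to True
mixed = torch.randint(2, (1, 1, 4, 4)).bool()            # roughly half True, half False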
@@ -193,18 +182,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

     if skip_wgrad:
         _disable_wgrads(block)
@@ -233,18 +211,24 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp_hidden_states)
+    loss = te_out.sum()
+    loss.backward()
+    torch.cuda.synchronize()
+
+
+def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
+    te_inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+    te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

     if skip_wgrad:
         _disable_wgrads(block)
@@ -261,18 +245,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
+    enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

     if skip_wgrad:
         _disable_wgrads(block)
@@ -282,7 +256,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
         te_out = block(
             te_inp_hidden_states,
             attention_mask=te_inp_attn_mask,
-            encoder_output=te_inp_hidden_states
+            encoder_output=te_inp_hidden_states,
+            enc_dec_attn_mask=enc_dec_attn_mask,
         )
     loss = te_out.sum()
     loss.backward()
@@ -541,13 +516,14 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
             apply_residual_connection_post_layernorm=True,
             output_layernorm=True,
             zero_centered_gamma=zero_centered_gamma,
+            self_attn_mask_type="padding",
             normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
......
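The new `_test_sanity_e2e_bert` path exercises a boolean padding mask of shape `[bs, 1, 1, seq_len]` through a layer built with `self_attn_mask_type="padding"`. A minimal sketch of that usage, assuming the usual `TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, ...)` constructor and the `[seq_len, batch, hidden]` input layout used in these tests:

import torch
import transformer_engine.pytorch as te

seq_len, bs, hidden = 128, 4, 1024
block = te.TransformerLayer(hidden, 4 * hidden, 16, self_attn_mask_type="padding").cuda()
hidden_states = torch.randn(seq_len, bs, hidden, device="cuda", requires_grad=True)
padding_mask = torch.rand(bs, 1, 1, seq_len, device="cuda") > 0.5   # random boolean padding mask
out = block(hidden_states, attention_mask=padding_mask)
out.sum().backward()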
This diff is collapsed.
@@ -22,7 +22,7 @@ TE_DType = {
     torch.bfloat16: tex.DType.kBFloat16,
 }

-AttnMaskTypes = ("causal", "padding", "no_mask")
+AttnMaskTypes = ("causal", "padding", "arbitrary", "no_mask")
 AttnTypes = ("self", "cross")
......
@@ -261,21 +261,22 @@ class FusedScaleMaskSoftmax(nn.Module):
             scale is None or self.softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

-        if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
+        if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode():
             return self.forward_fused_softmax(inp, mask, scale)
         return self.forward_torch_softmax(inp, mask, scale)

-    def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
+    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
         """Check FusedScaleMaskSoftmax kernel availability based on size"""
         attn_batches = b * np

         if (  # pylint: disable=too-many-boolean-expressions
-            self.scaled_masked_softmax_fusion  # user want to fuse
+            self.scaled_masked_softmax_fusion  # user wants to fuse
             and self.input_in_float16  # input must be fp16
             and 16 < sk <= 4096  # sk must be 16 ~ 2048
             and sk % 8 == 0  # sk must be divisor of 8
             and sq % 4 == 0  # sq must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
+            and self.attn_mask_type != "arbitrary"  # Custom masks not supported
         ):
             if 0 <= sk <= 4096:
                 batch_per_block = self.get_batch_per_block(int(sk))
@@ -283,6 +284,14 @@ class FusedScaleMaskSoftmax(nn.Module):
                 if self.attn_mask_type == "causal":
                     if attn_batches % batch_per_block == 0:
                         return True
+                elif self.attn_mask_type == "padding":
+                    if (
+                        mask is not None
+                        and sq % batch_per_block == 0
+                        and mask.shape[-2] == sq
+                        and mask.shape[-1] == sk
+                    ):
+                        return True
                 else:
                     if sq % batch_per_block == 0:
                         return True
@@ -303,7 +312,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
             return probs.view(b, np, sq, sk)
         # input is 4D tensor (b, np, sq, sk)
-        if mask is not None:
+        if mask is not None and self.attn_mask_type != "no_mask":
             return ScaledMaskedSoftmax.apply(inp, mask, scale)
         return ScaledSoftmax.apply(inp, scale)
@@ -325,7 +334,9 @@ class FusedScaleMaskSoftmax(nn.Module):
             else:
                 mask = _get_default_causal_mask(inp.size(2))

-        mask_output = self.mask_func(inp, mask) if mask is not None else inp
+        mask_output = inp
+        if mask is not None and self.attn_mask_type != "no_mask":
+            mask_output = self.mask_func(inp, mask)
         probs = torch.nn.Softmax(dim=-1)(mask_output)

         if self.input_in_float16 and self.softmax_in_fp32:
......
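When the fused kernel is not available, `forward_torch_softmax` applies the mask and then a regular softmax. In spirit (the exact `mask_func` is defined elsewhere in the library), the unfused path does something like the following, where a large negative fill keeps fully masked rows from producing NaNs:

import torch

def masked_softmax(scores, mask, scale=None):
    # scores: [b, np, sq, sk]; mask: broadcastable boolean (True = mask out, assumed convention)
    if scale is not None:
        scores = scores * scale
    if mask is not None:
        scores = scores.masked_fill(mask, -10000.0)   # large negative value instead of -inf
    return torch.softmax(scores, dim=-1)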
@@ -6,7 +6,7 @@
 import os
 import warnings
 from contextlib import nullcontext
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
@@ -127,7 +127,7 @@ class TransformerLayer(torch.nn.Module):
     kv_channels: int, default = `None`
         number of key-value channels. defaults to
         :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
-    self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+    self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
         type of attention mask passed into softmax operation. Overridden by
         :attr:`self_attn_mask_type` in the `forward` method. The forward
         arg is useful for dynamically changing mask types, e.g. a different
@@ -429,7 +429,7 @@ class TransformerLayer(torch.nn.Module):
     def set_context_parallel_running(
         self,
         cp_group: Union[dist_group_type, None],
-        cp_global_ranks: Union[int],
+        cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
         """Set CP group and CP dual-stream running"""
@@ -460,16 +460,17 @@ class TransformerLayer(torch.nn.Module):
         .. note::

-            Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
-            is set to `"causal"`.
+            Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
+            includes `"padding"` or `"arbitrary"`.

         Parameters
         ----------
         hidden_states : torch.Tensor
             Input tensor.
-        attention_mask : Optional[torch.Tensor], default = `None`
+        attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
             Boolean tensor used to mask out self-attention softmax input.
-        self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+            Can be a tuple of 2 masks for cross attention with padding masks.
+        self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
             type of attention mask passed into softmax operation.
         encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
......
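Per the updated docstring, the mask type can also be chosen per forward call, which is how a custom ("arbitrary") mask would be supplied. A sketch under the same assumptions about the constructor signature as in the earlier example:

import torch
import transformer_engine.pytorch as te

seq_len, bs, hidden = 64, 2, 512
layer = te.TransformerLayer(hidden, 4 * hidden, 8).cuda()
x = torch.randn(seq_len, bs, hidden, device="cuda")
arbitrary_mask = torch.randint(2, (1, 1, seq_len, seq_len), device="cuda").bool()
out = layer(x, attention_mask=arbitrary_mask, self_attn_mask_type="arbitrary")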