Unverified commit 47ca514a, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Support packed input for FA (#302)



* initial changes [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add padding mask support for FA
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm causal mask from tests and add padding
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix some conflicts
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* conflicts
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add unpadding mask
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix padding mask
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] fix API
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add packing and unpacking
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix atomic_add bf16 torch.compile
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Generate non all True masks
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core attention export and FusedAttn filter
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix all ONNX tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Memory optimization
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Optimizations and caching fixes in torch.dynamo
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Padding optimizations
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes and reviews
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d3157e2a
......@@ -612,14 +612,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
block = _test_e2e_checkpointing_get_model(config, dtype)
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
None,
)
loss = te_out.sum()
loss.backward()
......@@ -650,7 +649,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
for _ in range(steps // 2):
te_out = block(
te_inp_hidden_states,
te_inp_attn_mask,
None,
)
loss = te_out.sum()
loss.backward()
......
......@@ -316,9 +316,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
# See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
if attn_mask_type is None:
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_padding-no-mask"
attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
return attn_mask_str
......@@ -986,14 +986,14 @@ def test_export_layernorm_mlp(
@skip_FP8
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [
(torch.float32, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
])
def test_export_core_attention(
seed_default_rng,
......@@ -1014,7 +1014,7 @@ def test_export_core_attention(
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (query_layer, key_layer, value_layer, attention_mask)
......@@ -1043,9 +1043,8 @@ def test_export_core_attention(
test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
......
......@@ -157,18 +157,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
if skip_wgrad:
_disable_wgrads(block)
......@@ -193,18 +182,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
if skip_wgrad:
_disable_wgrads(block)
......@@ -233,18 +211,24 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
if skip_wgrad:
_disable_wgrads(block)
......@@ -261,18 +245,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
1,
1,
config.seq_len,
config.seq_len,
)
)
.cuda()
.bool()
)
te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
if skip_wgrad:
_disable_wgrads(block)
......@@ -282,7 +256,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_out = block(
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
encoder_output=te_inp_hidden_states
encoder_output=te_inp_hidden_states,
enc_dec_attn_mask=enc_dec_attn_mask,
)
loss = te_out.sum()
loss.backward()
......@@ -541,13 +516,14 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="padding",
normalization=normalization,
)
.to(dtype=dtype)
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
......
......@@ -8,11 +8,12 @@ import warnings
import math
from importlib.metadata import version
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union, Dict, List
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
from pkg_resources import packaging
import numpy as np
import torch
import torch.nn.functional as F
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
......@@ -50,6 +51,7 @@ from transformer_engine.pytorch.distributed import (
checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.6")
......@@ -65,9 +67,210 @@ else:
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
__all__ = ["DotProductAttention", "MultiheadAttention"]
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1,] containing the cumulative sequence
    lengths of every sample in the batch and the indices containing valid
    samples.

    The returned `indices` tensor has shape [batch_size * max_seqlen, 1, 1].
    Its first `num_valid` entries are the flattened positions of unmasked
    tokens; the remainder is filled with the out-of-range sentinel value
    `batch_size * max_seqlen`, which the pack/unpack helpers route to a
    dedicated padding row.
    """
    # [b, 1, 1, s] -> [b, s]
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    # Per-sample count of valid tokens -> cumulative lengths, prefixed with 0.
    reduced_mask = mask.sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    # Use the mask's own device (the original hard-coded "cuda", which fails
    # for CPU tensors and is redundant when the mask is already on GPU).
    zero = torch.zeros(1, dtype=torch.int32, device=mask.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # Flattened positions of valid tokens, shaped [num_valid, 1, 1].
    mask = mask.reshape(-1)
    indices = mask.nonzero()
    indices = indices.unsqueeze(-1)

    # Pad the index list back to a fixed bs * seqlen length with an
    # out-of-range sentinel so downstream gather/scatter work with static
    # shapes regardless of how many tokens are masked out.
    num_nonzeros = indices.shape[0]
    pad_amount = bs * seqlen - num_nonzeros
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))

    return cu_seqlens, indices
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Packs the given tensor using the `indices`.

    `tensor` is [total_tokens, h, d]; `indices` is [total_tokens, 1, 1] with
    valid row positions first and out-of-range sentinel entries afterwards
    (see `get_cu_seqlens_and_indices`). Sentinel entries gather the appended
    all-zero row, so the packed output keeps a static shape.
    """
    # Extra zero row at index `tensor.shape[0]`; sentinel indices map here.
    padding_indice = torch.zeros(
        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    tensor = torch.cat((tensor, padding_indice), dim=0)
    # Broadcast row indices over the feature dims so gather can consume them.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    packed = torch.gather(tensor, 0, indices)
    return packed
@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Packs the given 2 tensors using the `indices`.

    Thin wrapper so that both `pack_tensor` calls sit inside a single
    `jit_fuser`-decorated function.
    """
    t1_packed = pack_tensor(indices, t1)
    t2_packed = pack_tensor(indices, t2)
    return t1_packed, t2_packed
@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs the given 3 tensors using the `indices`.

    Thin wrapper so that all three `pack_tensor` calls sit inside a single
    `jit_fuser`-decorated function.
    """
    t1_packed = pack_tensor(indices, t1)
    t2_packed = pack_tensor(indices, t2)
    t3_packed = pack_tensor(indices, t3)
    return t1_packed, t2_packed, t3_packed
@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`.

    Scatters the packed rows of `tensor` back to a [dim0, h, d] layout;
    positions that were masked out remain zero-filled.
    """
    # Broadcast row indices over the feature dims so scatter_ can use them.
    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
    # Allocate one extra row as a dustbin: sentinel indices (set to
    # bs * seqlen == dim0 in `get_cu_seqlens_and_indices`) land there.
    unpacked = torch.zeros(
        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
    unpacked.scatter_(0, indices, tensor)
    # Drop the dustbin row, leaving exactly dim0 rows.
    unpacked = unpacked[0:-1,:,:]
    return unpacked
@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`.

    Thin wrapper so that both `unpack_tensor` calls sit inside a single
    `jit_fuser`-decorated function.
    """
    t1_unpacked = unpack_tensor(indices, dim0, t1)
    t2_unpacked = unpack_tensor(indices, dim0, t2)
    return t1_unpacked, t2_unpacked
@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`.

    Thin wrapper so that all three `unpack_tensor` calls sit inside a single
    `jit_fuser`-decorated function.
    """
    t1_unpacked = unpack_tensor(indices, dim0, t1)
    t2_unpacked = unpack_tensor(indices, dim0, t2)
    t3_unpacked = unpack_tensor(indices, dim0, t3)
    return t1_unpacked, t2_unpacked, t3_unpacked
class PackTensors(torch.autograd.Function):
    """
    Autograd function to pack tensors.

    Forward gathers the valid (unpadded) rows of up to three tensors via the
    pack helpers; backward scatters the incoming gradients back to the padded
    layout, so positions that were packed away receive zero gradient.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        # `indices` is an integer tensor with no gradient, so it is stashed
        # directly on ctx rather than via save_for_backward.
        ctx.indices = indices
        # Original (padded) leading dimension; backward needs it to rebuild
        # the gradients' shape.
        ctx.dim0 = tensors[0].shape[0]
        if len(tensors) == 1:
            return pack_tensor(indices, *tensors)
        if len(tensors) == 2:
            return pack_2_tensors(indices, *tensors)
        return pack_3_tensors(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        # Leading None is the gradient w.r.t. `indices` (non-differentiable).
        if len(grad_outputs) == 1:
            return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs)
        if len(grad_outputs) == 2:
            return None, *unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs)
        return None, *unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs)
class UnpackTensor(torch.autograd.Function):
    """
    Autograd function to unpack a tensor.

    Forward scatters a packed tensor back to its padded [dim0, h, d] layout;
    backward re-packs the incoming gradient so it matches the packed input.
    """
    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        # Integer tensor with no gradient; stash directly on ctx.
        ctx.indices = indices
        return unpack_tensor(indices, dim0, tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # Nones correspond to the non-differentiable `indices` and `dim0`.
        return None, None, pack_tensor(ctx.indices, grad_output)
def _unpack_attn_mask_type(attn_mask_type: str) -> Tuple[str, bool]:
    """
    Decompose a (possibly comma-separated) attention mask type string into a
    single base mask type plus a flag indicating whether a causal mask should
    also be applied. Also verifies that the requested combination is one an
    available attention backend can serve.
    """
    requested = attn_mask_type.split(',')
    assert (
        all(entry in AttnMaskTypes for entry in requested)
    ), f"Mask type {attn_mask_type} is not supported."

    # Pull out the causal component, if present; it acts as a toggle on top
    # of whatever base mask remains.
    apply_causal = "causal" in requested
    if apply_causal:
        requested.remove("causal")

    if not requested:
        # Purely causal masking.
        return "causal", True

    if len(requested) == 1:
        if apply_causal:
            # Causal can only be combined with a padding mask.
            assert requested[0] == "padding", \
                f"Causal + {requested[0]} masking not supported."
            return "padding", True
        # A single non-causal mask: arbitrary, padding, or no_mask.
        return requested[0], False

    raise RuntimeError("Unsupported combination of mask types.")
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
recv_tensor, recv_src,
cp_group, batch_p2p_comm):
......@@ -608,7 +811,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -900,6 +1103,8 @@ class FlashAttention(torch.nn.Module):
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
) -> None:
super().__init__()
......@@ -911,6 +1116,8 @@ class FlashAttention(torch.nn.Module):
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.attention_dropout = attention_dropout
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
def forward(
......@@ -918,12 +1125,13 @@ class FlashAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: Union[int] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
) -> torch.Tensor:
"""flash-attn fprop"""
......@@ -940,6 +1148,8 @@ class FlashAttention(torch.nn.Module):
qkv_layout in QKVLayouts
), f"FlashAttention does not support qkv_layout = {qkv_layout}!"
context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
if qkv_format == 'sbhd':
......@@ -953,32 +1163,63 @@ class FlashAttention(torch.nn.Module):
else:
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
if qkv_format == 'bshd':
elif qkv_format == 'bshd':
query_layer, key_layer, value_layer = [x.contiguous()
for x in (query_layer, key_layer, value_layer)]
if qkv_format in ['sbhd', 'bshd']:
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
if qkv_format == 'thd':
assert (cp_group is None or get_distributed_world_size(cp_group) == 1
), "thd format is not supported for context parallelism!"
if qkv_format in ['sbhd', 'bshd']:
if not context_parallel:
# [b * s, h, d]
query_layer, key_layer, value_layer = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [query_layer, key_layer, value_layer]
]
if attn_mask_type == 'padding':
assert not context_parallel, "Padding mask not supported with context parallelism."
if self.attention_type == "self":
assert (
max_seqlen_q == max_seqlen_kv
), "Maximum sequence length for Q and KV should be the same."
if self.layer_number == 1:
_cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
_cu_seqlens_kv = _cu_seqlens_q
query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
_indices_q, query_layer, key_layer, value_layer
)
else:
if self.layer_number == 1:
_cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask[0])
_cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
query_layer_packed = PackTensors.apply(_indices_q, query_layer)
key_layer_packed, value_layer_packed = PackTensors.apply(
_indices_kv, key_layer, value_layer
)
query_layer, key_layer, value_layer = (
query_layer_packed, key_layer_packed, value_layer_packed)
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
else:
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
elif qkv_format == 'thd':
assert not context_parallel, "thd format is not supported for context parallelism!"
assert (_flash_attn_2_available
), "flash-attn v2 is required for variable sequence length support!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
......@@ -988,41 +1229,37 @@ class FlashAttention(torch.nn.Module):
max_seqlen_q = seqlens_q.max().item()
max_seqlen_kv = seqlens_kv.max().item()
if cp_group is None or get_distributed_world_size(cp_group) == 1:
# [b * s, h, d]
query_layer, key_layer, value_layer = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [query_layer, key_layer, value_layer]
]
if context_parallel:
with self.attention_dropout_ctx():
fa_optional_forward_kwargs = {}
if not _flash_attn_2_available:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
output = flash_attn_forward_func(
output = flash_attn_forward_func_with_cp(
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
causal=attn_mask_type=="causal",
**fa_optional_forward_kwargs
deterministic=self.deterministic
)
else:
with self.attention_dropout_ctx():
output = flash_attn_forward_func_with_cp(
fa_optional_forward_kwargs = {}
if not _flash_attn_2_available:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
output = flash_attn_forward_func(
query_layer, key_layer, value_layer,
cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
cp_group, cp_global_ranks, cp_stream,
softmax_scale=1.0/self.norm_factor,
causal=attn_mask_type=="causal",
deterministic=self.deterministic
softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
**fa_optional_forward_kwargs
)
if attn_mask_type == 'padding':
output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)
if qkv_format == 'sbhd':
# (bs)hd -> bs(hd) -> sb(hd)
output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
if qkv_format == 'bshd':
elif qkv_format == 'bshd':
# (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q, -1).contiguous()
......@@ -1376,8 +1613,8 @@ class DotProductAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
Argument :attr:`attention_mask` in the `forward` call is only used when
:attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
.. warning::
......@@ -1402,6 +1639,21 @@ class DotProductAttention(torch.nn.Module):
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = `causal`
type of attention mask passed into softmax operation, options are "`causal`",
"`padding`", "`arbitrary`", "`no_mask`". For the "`causal`" mask,
TransformerEngine calculates and applies an upper triangular mask to
the softmax input. An "`arbitrary`" mask is an arbitrary user defined mask
broadcastable to the shape of softmax input. The "`padding`" mask is used
for providing locations of padded tokens in the batch, which should be of
the shape [batch_size, 1, 1, seq_len]. No mask is applied for the "`no_mask`"
option. For the `"arbitrary"` and `"padding"` mask types, the argument
:attr:`attention_mask` must be passed into `forward` call. The "`causal`"
mask can also be applied in conjunction with "`padding`" mask by passing
in multiple mask type as a comma separated string, for example,
`attn_mask_type="causal,padding"`.
attention_type: str, default = `self`
type of attention, either "`self`" and "`cross`".
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
......@@ -1415,7 +1667,7 @@ class DotProductAttention(torch.nn.Module):
have different lengths. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `_get_qkv_layout` to gain the layout information.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
......@@ -1456,7 +1708,7 @@ class DotProductAttention(torch.nn.Module):
layer_number: Optional[int] = None,
attention_type: str = "self",
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: Union[int] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
) -> None:
super().__init__()
......@@ -1507,23 +1759,31 @@ class DotProductAttention(torch.nn.Module):
and self.device_compute_capability >= 8.0
)
assert (
attention_type in AttnTypes
), f"attention_type {attention_type} not supported"
self.attention_type = attention_type
self.attention_dropout = attention_dropout
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
}
self.attention_type = attention_type
self.attention_dropout = attention_dropout
if self.use_flash_attention:
self.flash_attention = FlashAttention(
norm_factor, **attn_kwargs,
deterministic=self.deterministic)
self.flash_attention = FlashAttention(norm_factor,
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs)
# Instantiating three types since use of flash-attn and FusedAttention
# might be ruled out due to forward inputs.
if self.use_fused_attention:
self.fused_attention = FusedAttention(
norm_factor, **attn_kwargs,
attention_type = attention_type)
attention_type=attention_type)
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
......@@ -1554,7 +1814,7 @@ class DotProductAttention(torch.nn.Module):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
......@@ -1569,8 +1829,8 @@ class DotProductAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes '"padding"' or `"arbitrary"`.
.. note::
......@@ -1614,6 +1874,9 @@ class DotProductAttention(torch.nn.Module):
Key tensor.
value_layer : torch.Tensor
Value tensor.
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
Can be a tuple of 2 masks for cross attention with padding masks.
qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
......@@ -1622,9 +1885,7 @@ class DotProductAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
with shape [batch_size + 1] and dtype torch.int32.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
type of attention mask passed into softmax operation.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
......@@ -1639,6 +1900,10 @@ class DotProductAttention(torch.nn.Module):
Whether to use the fast path to set output tensors to 0 or not.
"""
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'DotProductAttention only supports CUDA tensors.'
assert (key_layer.shape == value_layer.shape
), "Keys and values must have the same shape!"
......@@ -1646,6 +1911,7 @@ class DotProductAttention(torch.nn.Module):
attn_mask_type = self.attn_mask_type
if qkv_format is None:
qkv_format = self.qkv_format
attn_mask_type, causal_mask = _unpack_attn_mask_type(attn_mask_type)
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
......@@ -1691,15 +1957,23 @@ class DotProductAttention(torch.nn.Module):
qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer,
qkv_format = qkv_format)
# The priority for attention backends (subject to availability and clearing the filters)
# is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
use_flash_attention = self.use_flash_attention
use_fused_attention = self.use_fused_attention
# The following section filters out some backends based on
# certain asserts before executing the forward pass.
# Filter: Input type.
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
use_flash_attention = False
use_fused_attention = False
# Filter: Device and dimensions.
if key_layer.shape[-1] > 64:
if self.device_compute_capability in (8.6, 8.7):
use_flash_attention = False
......@@ -1709,17 +1983,31 @@ class DotProductAttention(torch.nn.Module):
if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
use_flash_attention = False
if attn_mask_type == "padding" and attention_mask is not None:
use_flash_attention = False
use_fused_attention = False
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False
# Filter: ONNX export.
if is_in_onnx_export_mode():
use_flash_attention = False
use_fused_attention = False
# Filter: Attention mask type.
# attn_mask_type(s) | supported backends
# ------------------------------------------------
# causal | All
# padding | UnfusedDotProductAttention, FlashAttention
# arbitrary | UnfusedDotProductAttention
# no_mask | All
# causal + padding | FlashAttention
#
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
elif attn_mask_type == "padding" and causal_mask:
assert use_flash_attention, "No attention backend available for causal + padding masks."
elif attn_mask_type == "padding":
use_fused_attention = False
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
......@@ -1750,21 +2038,25 @@ class DotProductAttention(torch.nn.Module):
query_layer,
key_layer,
value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
cp_group = self.cp_group,
cp_global_ranks = self.cp_global_ranks,
cp_stream = self.cp_stream)
return self.flash_attention(query_layer, key_layer, value_layer,
qkv_layout = qkv_layout,
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_kv = cu_seqlens_kv,
attn_mask_type = attn_mask_type,
cp_group = self.cp_group,
cp_global_ranks = self.cp_global_ranks,
cp_stream = self.cp_stream)
attention_mask=attention_mask,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream)
return self.flash_attention(query_layer,
key_layer,
value_layer,
attention_mask=attention_mask,
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream)
assert (
self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
......@@ -1854,7 +2146,7 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
......@@ -2149,7 +2441,7 @@ class MultiheadAttention(torch.nn.Module):
def set_context_parallel_running(
self,
cp_group: Union[dist_group_type, None],
cp_global_ranks: Union[int],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group and CP dual-stream running"""
......@@ -2160,7 +2452,7 @@ class MultiheadAttention(torch.nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None,
is_first_microbatch: Optional[bool] = None,
......@@ -2185,7 +2477,7 @@ class MultiheadAttention(torch.nn.Module):
Input tensor.
attention_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = `None`
Boolean tensor(s) used to mask out self-attention softmax input.
Can be a tuple of 2 masks for cross attention with padding masks.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
......@@ -2230,6 +2522,7 @@ class MultiheadAttention(torch.nn.Module):
assert (core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
......
......@@ -22,7 +22,7 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "no_mask")
AttnMaskTypes = ("causal", "padding", "arbitrary", "no_mask")
AttnTypes = ("self", "cross")
......
......@@ -261,21 +261,22 @@ class FusedScaleMaskSoftmax(nn.Module):
scale is None or self.softmax_in_fp32
), "softmax should be in fp32 when scaled"
if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode():
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user want to fuse
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 4096
and sk % 8 == 0 # sk must be divisible by 8
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(int(sk))
......@@ -283,6 +284,14 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
elif self.attn_mask_type == "padding":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
):
return True
else:
if sq % batch_per_block == 0:
return True
......@@ -303,7 +312,7 @@ class FusedScaleMaskSoftmax(nn.Module):
probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
return probs.view(b, np, sq, sk)
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
if mask is not None and self.attn_mask_type != "no_mask":
return ScaledMaskedSoftmax.apply(inp, mask, scale)
return ScaledSoftmax.apply(inp, scale)
......@@ -325,7 +334,9 @@ class FusedScaleMaskSoftmax(nn.Module):
else:
mask = _get_default_causal_mask(inp.size(2))
mask_output = self.mask_func(inp, mask) if mask is not None else inp
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
mask_output = self.mask_func(inp, mask)
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
......
......@@ -6,7 +6,7 @@
import os
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
......@@ -127,7 +127,7 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`self_attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
......@@ -429,7 +429,7 @@ class TransformerLayer(torch.nn.Module):
def set_context_parallel_running(
self,
cp_group: Union[dist_group_type, None],
cp_global_ranks: Union[int],
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group and CP dual-stream running"""
......@@ -460,16 +460,17 @@ class TransformerLayer(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
is set to `"causal"`.
Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
includes `"padding"` or `"arbitrary"`.
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Boolean tensor used to mask out self-attention softmax input.
Can be a tuple of 2 masks for cross attention with padding masks.
self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment