Unverified Commit 47ca514a authored by Kirthi Shankar Sivamani, committed by GitHub

Support packed input for FA (#302)



* initial changes [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add padding mask support for FA
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm causal mask from tests and add padding
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix some conflicts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* conflicts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add unpadding mask
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix padding mask
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] fix API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add packing and unpacking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix atomic_add bf16 torch.compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Generate non all True masks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core attention export and FusedAttn filter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix all ONNX tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Memory optimization
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Optimizations and caching fixes in torch.dynamo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Padding optimizations
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes and reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d3157e2a
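
As a rough usage note (an editor's sketch, not part of this commit): the new padding support is driven by `attn_mask_type` and a boolean `attention_mask`. Argument names and shapes below follow the docstrings added in this diff; a CUDA device and a working flash-attn install are assumed.

    import torch
    from transformer_engine.pytorch.attention import DotProductAttention

    # Combine the causal mask with a padding mask (new in this PR).
    attn = DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,
        attn_mask_type="causal,padding",
    )

    b, s, h, d = 2, 128, 16, 64
    q = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda")
    k = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda")
    v = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda")

    # Padding mask of shape [batch_size, 1, 1, seq_len]; True marks valid tokens,
    # so the per-sample sum gives each sequence length (see get_cu_seqlens_and_indices below).
    padding_mask = torch.ones(b, 1, 1, s, dtype=torch.bool, device="cuda")
    padding_mask[1, ..., 96:] = False  # second sample is only 96 tokens long

    out = attn(q, k, v, attention_mask=padding_mask)  # flash-attn path with packed input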
@@ -612,14 +612,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()
-   te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
-           te_inp_attn_mask,
+           None,
        )
        loss = te_out.sum()
        loss.backward()

@@ -650,7 +649,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
-           te_inp_attn_mask,
+           None,
        )
        loss = te_out.sum()
        loss.backward()
...
@@ -316,9 +316,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
    # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
    if attn_mask_type is None:
        return "_mask" if use_mask else "_no-mask"
-   attn_mask_str = "_padding-no-mask"
+   attn_mask_str = "_arbitrary-no-mask"
    attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
-   attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
+   attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
    return attn_mask_str

@@ -986,13 +986,13 @@ def test_export_layernorm_mlp(
@skip_FP8
@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type", [
-       (torch.float32, True, "padding"),    # calls forward_torch_softmax (apply user mask)
+       (torch.float32, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.float32, False, "no_mask"),   # calls forward_torch_softmax (apply no mask)
        (torch.float16, False, "causal"),    # calls forward_torch_softmax (apply dynamic onnx mask)
-       (torch.float16, True, "padding"),    # calls forward_torch_softmax (apply user mask)
+       (torch.float16, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.float16, False, "no_mask"),   # calls forward_torch_softmax (apply no mask)
        (torch.bfloat16, False, "causal"),   # calls forward_torch_softmax (apply dynamic onnx mask)
-       (torch.bfloat16, True, "padding"),   # calls forward_torch_softmax (apply user mask)
+       (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
        (torch.bfloat16, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
    ])
def test_export_core_attention(

@@ -1014,7 +1014,7 @@ def test_export_core_attention(
    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
-       probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
+       probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    inp = (query_layer, key_layer, value_layer, attention_mask)

@@ -1043,9 +1043,8 @@ def test_export_core_attention(
test_configs_multihead_attention = [
    #"use_mask, attn_mask_type"
-   (False, "causal"),     # calls ScaledUpperTriangMaskedSoftmax
-   (True, "padding"),     # calls ScaledMaskedSoftmax
-   (False, "padding"),    # calls ScaledSoftmax
+   (False, "no_mask"),    # calls ScaledSoftmax
+   (True, "arbitrary"),   # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
    #"input_layernorm, attention_type, fuse_qkv_params"
...
@@ -157,18 +157,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
        config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()
-   te_inp_attn_mask = (
-       torch.rand(
-           (
-               1,
-               1,
-               config.seq_len,
-               config.seq_len,
-           )
-       )
-       .cuda()
-       .bool()
-   )
+   te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

@@ -193,18 +182,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
-   te_inp_attn_mask = (
-       torch.rand(
-           (
-               1,
-               1,
-               config.seq_len,
-               config.seq_len,
-           )
-       )
-       .cuda()
-       .bool()
-   )
+   te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

@@ -233,18 +211,24 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
-   te_inp_attn_mask = (
-       torch.rand(
-           (
-               1,
-               1,
-               config.seq_len,
-               config.seq_len,
-           )
-       )
-       .cuda()
-       .bool()
-   )
+
+   if skip_wgrad:
+       _disable_wgrads(block)
+
+   use_fp8 = fp8_recipe is not None
+   with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+       te_out = block(te_inp_hidden_states)
+   loss = te_out.sum()
+   loss.backward()
+   torch.cuda.synchronize()
+
+
+def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
+   te_inp_hidden_states = torch.randn(
+       config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+   ).cuda()
+   te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

@@ -261,18 +245,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
-   te_inp_attn_mask = (
-       torch.rand(
-           (
-               1,
-               1,
-               config.seq_len,
-               config.seq_len,
-           )
-       )
-       .cuda()
-       .bool()
-   )
+   te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
+   enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

@@ -282,7 +256,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    te_out = block(
        te_inp_hidden_states,
        attention_mask=te_inp_attn_mask,
-       encoder_output=te_inp_hidden_states
+       encoder_output=te_inp_hidden_states,
+       enc_dec_attn_mask=enc_dec_attn_mask,
    )
    loss = te_out.sum()
    loss.backward()

@@ -541,13 +516,14 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
+           self_attn_mask_type="padding",
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

-   _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+   _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
...
@@ -8,11 +8,12 @@ import warnings
import math
from importlib.metadata import version
from contextlib import nullcontext
-from typing import Any, Callable, Optional, Tuple, Union, Dict, List
+from typing import Any, Callable, List, Optional, Tuple, Union, Dict
from pkg_resources import packaging

import numpy as np
import torch
+import torch.nn.functional as F

import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import (

@@ -50,6 +51,7 @@ from transformer_engine.pytorch.distributed import (
    checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
+from transformer_engine.pytorch.jit import jit_fuser

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.6")

@@ -65,9 +67,210 @@ else:
    from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward

+_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None

__all__ = ["DotProductAttention", "MultiheadAttention"]
+
+def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
+    tensor of shape [batch_size + 1,] containing the cumulative sequence
+    lengths of every sample in the batch and the indices containing valid
+    samples.
+    """
+    mask = mask.squeeze(1).squeeze(1)
+    bs, seqlen = mask.shape
+    reduced_mask = mask.sum(dim=1)
+    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
+    zero = torch.zeros(1, dtype=torch.int32, device="cuda")
+    cu_seqlens = torch.cat((zero, cu_seqlens))
+    mask = mask.reshape(-1)
+    indices = mask.nonzero()
+    indices = indices.unsqueeze(-1)
+    num_nonzeros = indices.shape[0]
+    pad_amount = bs * seqlen - num_nonzeros
+    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
+                    mode="constant", value=float(bs * seqlen))
+    return cu_seqlens, indices
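
For illustration, a toy run of the helper above (an editor's sketch, not part of the diff; it assumes a CUDA device and that the helper is importable from this module):

    import torch
    from transformer_engine.pytorch.attention import get_cu_seqlens_and_indices

    # Batch of 2 sequences with max length 3: valid lengths 2 and 1.
    mask = torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.bool, device="cuda").view(2, 1, 1, 3)
    cu_seqlens, indices = get_cu_seqlens_and_indices(mask)
    # cu_seqlens       -> tensor([0, 2, 3], dtype=torch.int32)
    # indices[:, 0, 0] -> tensor([0, 1, 3, 6, 6, 6])   (invalid slots padded with bs * seqlen = 6)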
+
+
+@jit_fuser
+def pack_tensor(
+    indices: torch.Tensor,
+    tensor: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Packs the given tensor using the `indices`.
+    """
+    padding_indice = torch.zeros(
+        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
+    tensor = torch.cat((tensor, padding_indice), dim=0)
+    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
+    packed = torch.gather(tensor, 0, indices)
+    return packed
+
+
+@jit_fuser
+def pack_2_tensors(
+    indices: torch.Tensor,
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Packs the given 2 tensors using the `indices`.
+    """
+    t1_packed = pack_tensor(indices, t1)
+    t2_packed = pack_tensor(indices, t2)
+    return t1_packed, t2_packed
+
+
+@jit_fuser
+def pack_3_tensors(
+    indices: torch.Tensor,
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+    t3: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Packs the given 3 tensors using the `indices`.
+    """
+    t1_packed = pack_tensor(indices, t1)
+    t2_packed = pack_tensor(indices, t2)
+    t3_packed = pack_tensor(indices, t3)
+    return t1_packed, t2_packed, t3_packed
+
+
+@jit_fuser
+def unpack_tensor(
+    indices: torch.Tensor,
+    dim0: int,
+    tensor: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Inverse of `pack_tensor`.
+    """
+    indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
+    unpacked = torch.zeros(
+        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device)
+    unpacked.scatter_(0, indices, tensor)
+    unpacked = unpacked[0:-1, :, :]
+    return unpacked
+
+
+@jit_fuser
+def unpack_2_tensors(
+    indices: torch.Tensor,
+    dim0: int,
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Inverse of `pack_2_tensors`.
+    """
+    t1_unpacked = unpack_tensor(indices, dim0, t1)
+    t2_unpacked = unpack_tensor(indices, dim0, t2)
+    return t1_unpacked, t2_unpacked
+
+
+@jit_fuser
+def unpack_3_tensors(
+    indices: torch.Tensor,
+    dim0: int,
+    t1: torch.Tensor,
+    t2: torch.Tensor,
+    t3: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Inverse of `pack_3_tensors`.
+    """
+    t1_unpacked = unpack_tensor(indices, dim0, t1)
+    t2_unpacked = unpack_tensor(indices, dim0, t2)
+    t3_unpacked = unpack_tensor(indices, dim0, t3)
+    return t1_unpacked, t2_unpacked, t3_unpacked
+
+
+class PackTensors(torch.autograd.Function):
+    """
+    Autograd function to pack tensors.
+    """
+    @staticmethod
+    def forward(
+        ctx,
+        indices: torch.Tensor,
+        *tensors: Tuple[torch.Tensor, ...]
+    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
+        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
+        ctx.indices = indices
+        ctx.dim0 = tensors[0].shape[0]
+        if len(tensors) == 1:
+            return pack_tensor(indices, *tensors)
+        if len(tensors) == 2:
+            return pack_2_tensors(indices, *tensors)
+        return pack_3_tensors(indices, *tensors)
+
+    @staticmethod
+    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
+        if len(grad_outputs) == 1:
+            return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs)
+        if len(grad_outputs) == 2:
+            return None, *unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs)
+        return None, *unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs)
+
+
+class UnpackTensor(torch.autograd.Function):
+    """
+    Autograd function to unpack a tensor.
+    """
+    @staticmethod
+    def forward(
+        ctx,
+        indices: torch.Tensor,
+        dim0: int,
+        tensor: torch.Tensor,
+    ) -> torch.Tensor:
+        ctx.indices = indices
+        return unpack_tensor(indices, dim0, tensor)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return None, None, pack_tensor(ctx.indices, grad_output)
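
A matching sketch of the pack/unpack round trip (editor's illustration, not from the patch; it reuses `indices` from the snippet above, so `b * s = 6` with valid token positions 0, 1 and 3):

    t = torch.randn(6, 2, 4, device="cuda")            # [b * s, h, d]
    packed = PackTensors.apply(indices, t)             # rows 0, 1, 3 of t gathered to the front; padded slots read zeros
    restored = UnpackTensor.apply(indices, 6, packed)  # valid rows scattered back; padded rows come back as zeros
    assert restored[0].equal(t[0]) and restored[3].equal(t[3])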
+
+
+def _unpack_attn_mask_type(attn_mask_type: str) -> Tuple[str, bool]:
+    """
+    Unpacks the attention mask type string and returns a single mask type
+    and a boolean for whether to apply causal mask. Also ensures that the
+    combination of masks passed in is supported by one of the attention
+    backends available.
+    """
+    mask_types = attn_mask_type.split(',')
+    assert (
+        all(mask_type in AttnMaskTypes for mask_type in mask_types)
+    ), f"Mask type {attn_mask_type} is not supported."
+
+    # Whether or not to apply causal mask toggle.
+    causal_mask = False
+    if "causal" in mask_types:
+        mask_types.remove("causal")
+        causal_mask = True
+
+    if len(mask_types) == 0:  # Only apply causal mask.
+        return "causal", True
+    if len(mask_types) == 1 and causal_mask:  # Causal + padding masks
+        assert mask_types[0] == "padding", f"Causal + {mask_types[0]} masking not supported."
+        return "padding", True
+    if len(mask_types) == 1:  # Arbitrary or padding or no_mask
+        return mask_types[0], False
+    raise RuntimeError("Unsupported combination of mask types.")

def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):

@@ -608,7 +811,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        cu_seqlens_q: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        cu_seqlens_kv: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        attn_mask_type: str = "causal",
-       attention_mask: Optional[torch.Tensor] = None,
+       attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
@@ -900,6 +1103,8 @@ class FlashAttention(torch.nn.Module):
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
+       attention_type: str = "self",
+       layer_number: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        super().__init__()

@@ -911,6 +1116,8 @@ class FlashAttention(torch.nn.Module):
        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
+       self.attention_type = attention_type
+       self.layer_number = 1 if layer_number is None else layer_number
        self.deterministic = deterministic

    def forward(
@@ -918,12 +1125,13 @@ class FlashAttention(torch.nn.Module):
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
+       attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_layout: str = "sbh3d",
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        attn_mask_type: str = "causal",
        cp_group: Optional[dist_group_type] = None,
-       cp_global_ranks: Union[int] = None,
+       cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> torch.Tensor:
        """flash-attn fprop"""
@@ -940,6 +1148,8 @@ class FlashAttention(torch.nn.Module):
            qkv_layout in QKVLayouts
        ), f"FlashAttention does not support qkv_layout = {qkv_layout}!"

+       context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1)
+
        qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])

        if qkv_format == 'sbhd':

@@ -953,14 +1163,47 @@ class FlashAttention(torch.nn.Module):
            else:
                query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
                    for x in (query_layer, key_layer, value_layer)]
-       if qkv_format == 'bshd':
+       elif qkv_format == 'bshd':
            query_layer, key_layer, value_layer = [x.contiguous()
                for x in (query_layer, key_layer, value_layer)]

-       if qkv_format in ['sbhd', 'bshd']:
-           batch_size, max_seqlen_q, max_seqlen_kv = (
-               query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
+       global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
+       batch_size, max_seqlen_q, max_seqlen_kv = (
+           query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
+
+       if qkv_format in ['sbhd', 'bshd']:
+           if not context_parallel:
+               # [b * s, h, d]
+               query_layer, key_layer, value_layer = [
+                   x.view(x.shape[0] * x.shape[1], *x.shape[2:])
+                   for x in [query_layer, key_layer, value_layer]
+               ]
+
+           if attn_mask_type == 'padding':
+               assert not context_parallel, "Padding mask not supported with context parallelism."
+
+               if self.attention_type == "self":
+                   assert (
+                       max_seqlen_q == max_seqlen_kv
+                   ), "Maximum sequence length for Q and KV should be the same."
+                   if self.layer_number == 1:
+                       _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
+                       _cu_seqlens_kv = _cu_seqlens_q
+                   query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
+                       _indices_q, query_layer, key_layer, value_layer
+                   )
+               else:
+                   if self.layer_number == 1:
+                       _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask[0])
+                       _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(attention_mask[1])
+                   query_layer_packed = PackTensors.apply(_indices_q, query_layer)
+                   key_layer_packed, value_layer_packed = PackTensors.apply(
+                       _indices_kv, key_layer, value_layer
+                   )
+               query_layer, key_layer, value_layer = (
+                   query_layer_packed, key_layer_packed, value_layer_packed)
+               cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
+           else:
                if cu_seqlens_q is None:
                    cu_seqlens_q = torch.arange(
                        0,

@@ -975,10 +1218,8 @@ class FlashAttention(torch.nn.Module):
                        step=max_seqlen_kv,
                        dtype=torch.int32,
                        device=key_layer.device)
-       if qkv_format == 'thd':
-           assert (cp_group is None or get_distributed_world_size(cp_group) == 1
-           ), "thd format is not supported for context parallelism!"
+       elif qkv_format == 'thd':
+           assert not context_parallel, "thd format is not supported for context parallelism!"
            assert (_flash_attn_2_available
            ), "flash-attn v2 is required for variable sequence length support!"
            assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
@@ -988,41 +1229,37 @@ class FlashAttention(torch.nn.Module):
            max_seqlen_q = seqlens_q.max().item()
            max_seqlen_kv = seqlens_kv.max().item()

-       if cp_group is None or get_distributed_world_size(cp_group) == 1:
-           # [b * s, h, d]
-           query_layer, key_layer, value_layer = [
-               x.view(x.shape[0] * x.shape[1], *x.shape[2:])
-               for x in [query_layer, key_layer, value_layer]
-           ]
+       if context_parallel:
            with self.attention_dropout_ctx():
-               fa_optional_forward_kwargs = {}
-               if not _flash_attn_2_available:
-                   fa_optional_forward_kwargs["deterministic"] = self.deterministic
-               output = flash_attn_forward_func(
+               output = flash_attn_forward_func_with_cp(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
+                   cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    causal=attn_mask_type=="causal",
-                   **fa_optional_forward_kwargs
+                   deterministic=self.deterministic
                )
        else:
            with self.attention_dropout_ctx():
-               output = flash_attn_forward_func_with_cp(
+               fa_optional_forward_kwargs = {}
+               if not _flash_attn_2_available:
+                   fa_optional_forward_kwargs["deterministic"] = self.deterministic
+               output = flash_attn_forward_func(
                    query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
-                   cp_group, cp_global_ranks, cp_stream,
-                   softmax_scale=1.0/self.norm_factor,
-                   causal=attn_mask_type=="causal",
-                   deterministic=self.deterministic
+                   softmax_scale=1.0/self.norm_factor, causal=attn_mask_type=="causal",
+                   **fa_optional_forward_kwargs
                )

+       if attn_mask_type == 'padding':
+           output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)
+
        if qkv_format == 'sbhd':
            # (bs)hd -> bs(hd) -> sb(hd)
            output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous()
-       if qkv_format == 'bshd':
+       elif qkv_format == 'bshd':
            # (bs)hd -> bs(hd)
            output = output.view(batch_size, max_seqlen_q, -1).contiguous()
@@ -1376,8 +1613,8 @@ class DotProductAttention(torch.nn.Module):

    .. note::

-       Argument :attr:`attention_mask` will be ignored in the `forward` call when
-       :attr:`attn_mask_type` is set to `"causal"`.
+       Argument :attr:`attention_mask` in the `forward` call is only used when
+       :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.

    .. warning::

@@ -1402,6 +1639,21 @@ class DotProductAttention(torch.nn.Module):
                   is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
+   attn_mask_type: str, default = `causal`
+                  type of attention mask passed into softmax operation, options are "`causal`",
+                  "`padding`", "`arbitrary`", "`no_mask`". For the "`causal`" mask,
+                  TransformerEngine calculates and applies an upper triangular mask to
+                  the softmax input. An "`arbitrary`" mask is an arbitrary user defined mask
+                  broadcastable to the shape of the softmax input. The "`padding`" mask is used
+                  for providing locations of padded tokens in the batch, and should be of
+                  the shape [batch_size, 1, 1, seq_len]. No mask is applied for the "`no_mask`"
+                  option. For the "`arbitrary`" and "`padding`" mask types, the argument
+                  :attr:`attention_mask` must be passed into the `forward` call. The "`causal`"
+                  mask can also be applied in conjunction with the "`padding`" mask by passing
+                  in multiple mask types as a comma-separated string, for example,
+                  `attn_mask_type="causal,padding"`.
+   attention_type: str, default = `self`
+                  type of attention, either "`self`" or "`cross`".
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.

@@ -1415,7 +1667,7 @@ class DotProductAttention(torch.nn.Module):
              have different lengths. Please note that these formats do not reflect how
              tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
              For that, please use `_get_qkv_layout` to gain the layout information.
-   attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+   attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different

@@ -1456,7 +1708,7 @@ class DotProductAttention(torch.nn.Module):
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[dist_group_type] = None,
-       cp_global_ranks: Union[int] = None,
+       cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
    ) -> None:
        super().__init__()
@@ -1507,23 +1759,31 @@ class DotProductAttention(torch.nn.Module):
            and self.device_compute_capability >= 8.0
        )

+       assert (
+           attention_type in AttnTypes
+       ), f"attention_type {attention_type} not supported"
+
+       self.attention_type = attention_type
+       self.attention_dropout = attention_dropout
+
        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
        }
-       self.attention_type = attention_type
-       self.attention_dropout = attention_dropout

        if self.use_flash_attention:
-           self.flash_attention = FlashAttention(
-               norm_factor, **attn_kwargs,
-               deterministic=self.deterministic)
+           self.flash_attention = FlashAttention(norm_factor,
+                                                 attention_type=attention_type,
+                                                 layer_number=layer_number,
+                                                 deterministic=self.deterministic,
+                                                 **attn_kwargs)
        # Instantiating three types since use of flash-attn and FusedAttention
        # might be ruled out due to forward inputs.
        if self.use_fused_attention:
            self.fused_attention = FusedAttention(
                norm_factor, **attn_kwargs,
-               attention_type = attention_type)
+               attention_type=attention_type)
        self.unfused_attention = UnfusedDotProductAttention(
            norm_factor, **attn_kwargs, layer_number=layer_number)
@@ -1554,7 +1814,7 @@ class DotProductAttention(torch.nn.Module):
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
-       attention_mask: Optional[torch.Tensor] = None,
+       attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        qkv_format: Optional[str] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,

@@ -1569,8 +1829,8 @@ class DotProductAttention(torch.nn.Module):

        .. note::

-           Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
-           is set to `"causal"`.
+           Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
+           includes `"padding"` or `"arbitrary"`.

        .. note::

@@ -1614,6 +1874,9 @@ class DotProductAttention(torch.nn.Module):
                     Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
+       attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
+                       Boolean tensor used to mask out softmax input when not using flash-attn.
+                       Can be a tuple of 2 masks for cross attention with padding masks.
        qkv_format: str, default = `None`
                   If provided, overrides :attr:`qkv_format` from initialization.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`

@@ -1622,9 +1885,7 @@ class DotProductAttention(torch.nn.Module):
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                      Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                      with shape [batch_size + 1] and dtype torch.int32.
-       attention_mask : Optional[torch.Tensor], default = `None`
-                       Boolean tensor used to mask out softmax input when not using flash-attn.
-       attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
+       attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
                       type of attention mask passed into softmax operation.
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed

@@ -1639,6 +1900,10 @@ class DotProductAttention(torch.nn.Module):
                          Whether to use the fast path to set output tensors to 0 or not.
        """

+       assert (
+           query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
+       ), 'DotProductAttention only supports CUDA tensors.'
+
        assert (key_layer.shape == value_layer.shape
            ), "Keys and values must have the same shape!"

@@ -1646,6 +1911,7 @@ class DotProductAttention(torch.nn.Module):
            attn_mask_type = self.attn_mask_type
        if qkv_format is None:
            qkv_format = self.qkv_format
+       attn_mask_type, causal_mask = _unpack_attn_mask_type(attn_mask_type)

        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
            and value_layer.shape[-2] == self.num_gqa_groups_per_partition
@@ -1691,15 +1957,23 @@ class DotProductAttention(torch.nn.Module):
            qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer,
                                         qkv_format = qkv_format)

+       # The priority for attention backends (subject to availability and clearing the filters)
+       # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
        use_flash_attention = self.use_flash_attention
        use_fused_attention = self.use_fused_attention

+       # The following section filters out some backends based on
+       # certain asserts before executing the forward pass.
+
+       # Filter: Input type.
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
        ):
            use_flash_attention = False
+           use_fused_attention = False

+       # Filter: Device and dimensions.
        if key_layer.shape[-1] > 64:
            if self.device_compute_capability in (8.6, 8.7):
                use_flash_attention = False

@@ -1709,17 +1983,31 @@ class DotProductAttention(torch.nn.Module):
        if not _flash_attn_2_available and self.num_gqa_groups != self.num_attention_heads:
            use_flash_attention = False

-       if attn_mask_type == "padding" and attention_mask is not None:
-           use_flash_attention = False
-           use_fused_attention = False
-
        if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
            use_flash_attention = False

+       # Filter: ONNX export.
        if is_in_onnx_export_mode():
            use_flash_attention = False
            use_fused_attention = False

+       # Filter: Attention mask type.
+       #    attn_mask_type(s)  |  supported backends
+       # ------------------------------------------------
+       #   causal              |  All
+       #   padding             |  UnfusedDotProductAttention, FlashAttention
+       #   arbitrary           |  UnfusedDotProductAttention
+       #   no_mask             |  All
+       #   causal + padding    |  FlashAttention
+       #
+       if attn_mask_type == "arbitrary":
+           use_flash_attention = False
+           use_fused_attention = False
+       elif attn_mask_type == "padding" and causal_mask:
+           assert use_flash_attention, "No attention backend available for causal + padding masks."
+       elif attn_mask_type == "padding":
+           use_fused_attention = False
+
        if use_fused_attention:
            fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype],
@@ -1750,21 +2038,25 @@ class DotProductAttention(torch.nn.Module):
                        query_layer,
                        key_layer,
                        value_layer,
-                       qkv_layout = qkv_layout,
-                       cu_seqlens_q = cu_seqlens_q,
-                       cu_seqlens_kv = cu_seqlens_kv,
-                       attn_mask_type = attn_mask_type,
-                       cp_group = self.cp_group,
-                       cp_global_ranks = self.cp_global_ranks,
-                       cp_stream = self.cp_stream)
-           return self.flash_attention(query_layer, key_layer, value_layer,
-                       qkv_layout = qkv_layout,
-                       cu_seqlens_q = cu_seqlens_q,
-                       cu_seqlens_kv = cu_seqlens_kv,
-                       attn_mask_type = attn_mask_type,
-                       cp_group = self.cp_group,
-                       cp_global_ranks = self.cp_global_ranks,
-                       cp_stream = self.cp_stream)
+                       attention_mask=attention_mask,
+                       qkv_layout=qkv_layout,
+                       cu_seqlens_q=cu_seqlens_q,
+                       cu_seqlens_kv=cu_seqlens_kv,
+                       attn_mask_type=attn_mask_type,
+                       cp_group=self.cp_group,
+                       cp_global_ranks=self.cp_global_ranks,
+                       cp_stream=self.cp_stream)
+           return self.flash_attention(query_layer,
+                                       key_layer,
+                                       value_layer,
+                                       attention_mask=attention_mask,
+                                       qkv_layout=qkv_layout,
+                                       cu_seqlens_q=cu_seqlens_q,
+                                       cu_seqlens_kv=cu_seqlens_kv,
+                                       attn_mask_type=attn_mask_type,
+                                       cp_group=self.cp_group,
+                                       cp_global_ranks=self.cp_global_ranks,
+                                       cp_stream=self.cp_stream)
        assert (
            self.cp_group is None or get_distributed_world_size(self.cp_group) == 1

@@ -1854,7 +2146,7 @@ class MultiheadAttention(torch.nn.Module):
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
-   attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+   attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
                   type of attention mask passed into softmax operation. Overridden by
                   :attr:`attn_mask_type` in the `forward` method. The forward
                   arg is useful for dynamically changing mask types, e.g. a different

@@ -2149,7 +2441,7 @@ class MultiheadAttention(torch.nn.Module):
    def set_context_parallel_running(
        self,
        cp_group: Union[dist_group_type, None],
-       cp_global_ranks: Union[int],
+       cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """Set CP group and CP dual-stream running"""

@@ -2160,7 +2452,7 @@ class MultiheadAttention(torch.nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
-       attention_mask: Optional[torch.Tensor] = None,
+       attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        is_first_microbatch: Optional[bool] = None,

@@ -2185,7 +2477,7 @@ class MultiheadAttention(torch.nn.Module):
                      Input tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
                        Boolean tensor used to mask out self-attention softmax input.
-       attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
+       attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `None`
                       type of attention mask passed into softmax operation.
        encoder_output : Optional[torch.Tensor], default = `None`
                        Output of the encoder block to be fed into the decoder block if using

@@ -2230,6 +2522,7 @@ class MultiheadAttention(torch.nn.Module):
        assert (core_attention_bias_type in AttnBiasTypes
            ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
...
@@ -22,7 +22,7 @@ TE_DType = {
    torch.bfloat16: tex.DType.kBFloat16,
}

-AttnMaskTypes = ("causal", "padding", "no_mask")
+AttnMaskTypes = ("causal", "padding", "arbitrary", "no_mask")

AttnTypes = ("self", "cross")
...
@@ -261,21 +261,22 @@ class FusedScaleMaskSoftmax(nn.Module):
            scale is None or self.softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

-       if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
+       if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode():
            return self.forward_fused_softmax(inp, mask, scale)
        return self.forward_torch_softmax(inp, mask, scale)

-   def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
+   def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
        """Check FusedScaleMaskSoftmax kernel availability based on size"""
        attn_batches = b * np

        if (  # pylint: disable=too-many-boolean-expressions
-           self.scaled_masked_softmax_fusion  # user want to fuse
+           self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 4096  # sk must be 16 ~ 2048
            and sk % 8 == 0  # sk must be divisor of 8
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
+           and self.attn_mask_type != "arbitrary"  # Custom masks not supported
        ):
            if 0 <= sk <= 4096:
                batch_per_block = self.get_batch_per_block(int(sk))

@@ -283,6 +284,14 @@ class FusedScaleMaskSoftmax(nn.Module):
                if self.attn_mask_type == "causal":
                    if attn_batches % batch_per_block == 0:
                        return True
+               elif self.attn_mask_type == "padding":
+                   if (
+                       mask is not None
+                       and sq % batch_per_block == 0
+                       and mask.shape[-2] == sq
+                       and mask.shape[-1] == sk
+                   ):
+                       return True
                else:
                    if sq % batch_per_block == 0:
                        return True

@@ -303,7 +312,7 @@ class FusedScaleMaskSoftmax(nn.Module):
            probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
            return probs.view(b, np, sq, sk)
        # input is 4D tensor (b, np, sq, sk)
-       if mask is not None:
+       if mask is not None and self.attn_mask_type != "no_mask":
            return ScaledMaskedSoftmax.apply(inp, mask, scale)
        return ScaledSoftmax.apply(inp, scale)

@@ -325,7 +334,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        else:
            mask = _get_default_causal_mask(inp.size(2))

-       mask_output = self.mask_func(inp, mask) if mask is not None else inp
+       mask_output = inp
+       if mask is not None and self.attn_mask_type != "no_mask":
+           mask_output = self.mask_func(inp, mask)
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
...
@@ -6,7 +6,7 @@
import os
import warnings
from contextlib import nullcontext
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

import torch

@@ -127,7 +127,7 @@ class TransformerLayer(torch.nn.Module):
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
-   self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+   self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
                        type of attention mask passed into softmax operation. Overridden by
                        :attr:`self_attn_mask_type` in the `forward` method. The forward
                        arg is useful for dynamically changing mask types, e.g. a different

@@ -429,7 +429,7 @@ class TransformerLayer(torch.nn.Module):
    def set_context_parallel_running(
        self,
        cp_group: Union[dist_group_type, None],
-       cp_global_ranks: Union[int],
+       cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """Set CP group and CP dual-stream running"""

@@ -460,16 +460,17 @@ class TransformerLayer(torch.nn.Module):

        .. note::

-           Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
-           is set to `"causal"`.
+           Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
+           includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
                       Input tensor.
-       attention_mask : Optional[torch.Tensor], default = `None`
-                       Boolean tensor used to mask out self-attention softmax input.
-       self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
+       attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
+                       Boolean tensor used to mask out self-attention softmax input.
+                       Can be a tuple of 2 masks for cross attention with padding masks.
+       self_attn_mask_type: {'causal', 'padding', 'no_mask', 'arbitrary'}, default = `causal`
                            type of attention mask passed into softmax operation.
        encoder_output : Optional[torch.Tensor], default = `None`
                        Output of the encoder block to be fed into the decoder block if using
...