"qa/vscode:/vscode.git/clone" did not exist on "c525760538b5cb1b77f3d93ab2c98d75b9453f43"
Unverified Commit 5e4e0b2c authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Add sink attention support from cuDNN (#2148)
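Sink attention keeps an extra term in the softmax denominator that attends to no token: "off-by-one" adds a constant 1, and "learnable" adds exp(alpha[h]) with a trainable per-head offset (see the NVTE_Softmax_Type docstring added below). This PR exposes it as a new softmax_type argument on DotProductAttention and plumbs it through the C API down to cuDNN. A minimal usage sketch based on the test changes in this diff — the shapes are illustrative and the argument order follows the tests:

import torch
from transformer_engine.pytorch import DotProductAttention

# softmax_type: "vanilla" (default), "off-by-one", or "learnable"
core_attn = DotProductAttention(
    16,  # num_attention_heads
    64,  # kv_channels (head dim)
    qkv_format="bshd",
    attn_mask_type="causal",
    softmax_type="learnable",
).cuda()
core_attn.softmax_offset.requires_grad = True  # mirrors run_dpa_with_cp below

# [batch, seqlen, heads, head_dim] inputs; sizes here are illustrative
q, k, v = (
    torch.randn(2, 4096, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
)
out = core_attn(q, k, v)
out.backward(torch.randn_like(out))
d_sink = core_attn.softmax_offset.grad  # float32, per "force sink/dsink to be float32"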



* first draft; debug plan failure
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug uid error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak params
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add grad in output
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix prints in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* address review comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused grad; add softmax_type; add sink to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding mask; add swa tests; remove requires_grad for off-by-one
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix non-determinism and shapes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add GQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add CP A2A; dq/dk mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; need cleaner solution
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; pending cudnn kernel change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix world size in unit test; avoid thd format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 context
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak CP logging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow no_mask/padding for SWA(left,0)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "allow no_mask/padding for SWA(left,0)"

This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add softmax_type to Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add cuDNN version control
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prettify tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip 9.13 for MLA, non 192/128
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename compare_with_error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* small cleanups and improvements
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix minor CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force sink/dsink to be float32
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch FE to GH FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return to GH TE main FE commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.14.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up before CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* bump up cudnn version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add backend selection guard for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for softmax type enums in C
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 57b4d7bc
Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5
Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8
......@@ -17,88 +17,18 @@ from test_attention_with_cp import model_configs_flash_attn, model_configs_fused
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
from utils import ModelConfig, compare_and_assert
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
def generate_input_shapes(
qkv_format: str,
config: ModelConfig,
world_size: int,
kernel_backend: str,
):
"""Test DotProductAttention module with context parallelism"""
# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (
config.batch_size,
......@@ -191,34 +121,158 @@ def run_dpa_with_cp(
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
assert False, f"{qkv_format=} is not supported!"
return (
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
)
def get_tols(config, dtype):
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
atol = 2.5e-2
rtol = 2.5e-2
else:
atol = 3.5e-2
rtol = 3.5e-2
rmse_tol = 0.01
elif dtype == "fp16":
atol = 5e-3
rtol = 5e-3
rmse_tol = 0.01
elif dtype == "fp8":
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.1
else:
assert False, f"{dtype=} is not supported!"
return atol, rtol, rmse_tol
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
logging.root.setLevel(log_level)
# set up environment variables and config
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type=} is not supported!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# set up communication group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
f"{cp_comm_type=} requires an even world_size, as it assumes the a2a level has cp_size = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
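For example, with world_size = 4 the two comprehensions above produce the size-2 a2a groups first and the strided p2p groups second, so every rank lands in exactly one group of each level (lists used here only to make the ranges printable):

world_size = 4
sub_ranks = [list(range(i * 2, (i + 1) * 2)) for i in range(world_size // 2)]
sub_ranks += [list(range(i, world_size, 2)) for i in range(2)]
print(sub_ranks)  # [[0, 1], [2, 3], [0, 2], [1, 3]]
# rank 1, say, belongs to the a2a group [0, 1] and the p2p group [1, 3]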
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate attention module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
# generate attention inputs
(
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
for x in [q, k, v]:
x.requires_grad = True
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
if fp8_mha:
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
############ run without CP ############
logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out = core_attn(
q,
......@@ -236,8 +290,30 @@ def run_dpa_with_cp(
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
# run core_attn with CP
############ run with CP ############
logging.info(f"[Rank {rank}] Run with context parallelism")
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if config.softmax_type != "vanilla":
core_attn.softmax_offset.grad.zero_()
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# set up inputs
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
]
......@@ -267,8 +343,6 @@ def run_dpa_with_cp(
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
......@@ -276,19 +350,8 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
# run attention
with fp8_context:
out_ = core_attn(
q_,
......@@ -306,18 +369,23 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
# compare results with and without CP
# get outputs
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
if x is not None:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
......@@ -373,56 +441,70 @@ def run_dpa_with_cp(
).item()
== 0
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
if qkv_format == "bshd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[:, 0], b[:, 0])
_error(a[:, 1], b[:, 1])
elif qkv_format == "sbhd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[0], b[0])
_error(a[1], b[1])
elif qkv_format == "thd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a, b)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
else:
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")
# destroy distribution group
dist.destroy_process_group()
......
......@@ -6,6 +6,7 @@ import os
import subprocess
import sys
import pathlib
import logging
import pytest
import torch
......@@ -19,13 +20,15 @@ _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
pytest_logging_level = logging.getLevelName(logging.root.level)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
......@@ -72,6 +75,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
......@@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
)
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
available_backends, *_ = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
)
flash_attn_supported, *_ = available_backends
if not flash_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
......@@ -98,13 +112,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
log_level=pytest_logging_level,
),
check=True,
)
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
......@@ -135,6 +150,15 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
"cp_4_0": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
), # GQA
"cp_4_1": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
), # GQA
"cp_4_2": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
), # GQA
}
......@@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("FP8 attention is only supported on sm90+!")
config = model_configs_fused_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
......@@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
if dtype == "fp8" and config.softmax_type != "vanilla":
pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
......@@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
fp8_mha=fp8_mha,
log_level=pytest_logging_level,
),
check=True,
)
......@@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
......
......@@ -20,6 +20,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
......@@ -137,6 +138,31 @@ def reset_rng_states() -> None:
torch.cuda.set_rng_state(cuda_rng_state)
def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
"""Assert that tensors a and b match elementwise; for FP8, fall back to a range-normalized RMSE check."""
if not is_fp8:
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
return
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = torch.sqrt((a - b).square().mean()).item()
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
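A typical call, mirroring how the CP test above drives it (the tensor names here are illustrative):

atol, rtol, rmse_tol = get_tols(config, dtype)
compare_and_assert(
    out_no_cp, out_cp, "out_no_cp", "out_cp", atol, rtol, rmse_tol, dtype == "fp8"
)

The FP8 branch deliberately tolerates elementwise outliers: quantization noise makes a strict assert_close too brittle, so the fallback only requires the RMSE to stay under rmse_tol times the tensors' dynamic range.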
class ModelConfig:
def __init__(
self,
......@@ -147,12 +173,15 @@ class ModelConfig:
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
softmax_type: str = "vanilla",
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
......@@ -171,13 +200,16 @@ class ModelConfig:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.softmax_type = softmax_type
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = window_size
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
......@@ -198,9 +230,7 @@ def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
......@@ -250,19 +280,21 @@ def get_available_attention_backends(
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
window_size=config.window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
context_parallel=config.context_parallel,
cp_comm_type=config.cp_comm_type,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
)
(
use_flash_attention,
......
......@@ -21,17 +21,19 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -1695,6 +1695,7 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
......@@ -2000,6 +2001,7 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
false,
......
......@@ -107,6 +107,7 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
......@@ -116,14 +117,15 @@ struct FADescriptor_v1 {
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type,
bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
rhs.bwd_tensor_type);
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
};
......
......@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Softmax_Type
* \brief Attention softmax types as described in
* Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
* For a given attention score S = Q*K^T, different softmax types perform different operations on S,
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter in shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX = 0,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX = 1,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX = 2,
};
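A reference sketch of the three variants in PyTorch, following the formulas in the docstring above (reference math only, not the cuDNN kernel; max-subtraction stabilization is omitted for brevity):

import torch

def softmax_with_sink(S, softmax_type="vanilla", alpha=None):
    # S: [b, h, sq, skv] attention scores; alpha: [h] learnable per-head offset
    exp_s = S.exp()
    denom = exp_s.sum(dim=-1, keepdim=True)
    if softmax_type == "off-by-one":
        denom = denom + 1.0  # implicit sink token with logit 0
    elif softmax_type == "learnable":
        denom = denom + alpha.exp().view(1, -1, 1, 1)  # sink logit alpha[h]
    return exp_s / denom  # rows sum to < 1; the missing mass is the attention sink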
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
......@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
......@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
......@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
......@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
......
......@@ -36,6 +36,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
......
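On the Python side these values surface through the transformer_engine_torch extension (imported as tex elsewhere in this diff); a plausible mapping from the string softmax_type to the enum — the dict itself is hypothetical, the enum names come from the binding above:

import transformer_engine_torch as tex

softmax_type_to_enum = {
    "vanilla": tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
    "off-by-one": tex.NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
    "learnable": tex.NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}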
......@@ -13,6 +13,7 @@ import logging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
SplitAlongDim,
......@@ -142,6 +143,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -149,6 +151,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y):
return (
......@@ -185,6 +188,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
......@@ -326,7 +330,21 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype
)
# attention scores and attention mask [b, np, sq, sk]
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale
......@@ -337,6 +355,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
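The concatenate-then-drop trick above reproduces the sink denominator exactly: softmaxing over sk+1 columns where the extra column holds the offset logit, then slicing that column off, leaves exp(S_i) / (exp(offset) + sum(exp(S))). A standalone check of that identity (shapes illustrative; this mirrors only the unfused path, not the cuDNN kernel):

import torch

b, h, sq, sk = 2, 4, 8, 8
scores = torch.randn(b, h, sq, sk)
offset = torch.randn(1, h, 1, 1)  # sink logit per head

# path 1: append the sink column, softmax over sk+1, drop the column
augmented = torch.cat([scores, offset.expand(b, -1, sq, -1)], dim=-1)
probs_trick = torch.softmax(augmented, dim=-1)[..., :-1]

# path 2: direct off-by-one/learnable denominator
probs_direct = scores.exp() / (scores.exp().sum(-1, keepdim=True) + offset.exp())

torch.testing.assert_close(probs_trick, probs_direct, atol=1e-6, rtol=1e-5)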
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
......@@ -917,6 +939,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
fused_attention_backend,
......@@ -925,6 +948,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
quantizers,
deterministic,
softmax_offset,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
......@@ -997,8 +1021,10 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
if is_output_fp8:
out_ret = out_fp8
......@@ -1059,8 +1085,10 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
out_save = out_ret
fp8_tensors = (None, None, None, None)
......@@ -1114,6 +1142,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
......@@ -1224,6 +1253,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
......@@ -1287,42 +1317,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
return (
None,
None,
None,
None,
None,
None,
None,
None,
None,
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
d_bias = None
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
d_softmax_offset = None
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
return (
None,
None,
......@@ -1336,7 +1341,8 @@ class FusedAttnFunc(torch.autograd.Function):
dq,
dk,
dv,
rest[0],
d_bias,
None,
None,
None,
None,
......@@ -1351,6 +1357,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
d_softmax_offset,
)
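Background for the return-tuple reshuffle above: torch.autograd.Function requires backward() to return exactly one value per forward() argument, with None for non-differentiable inputs, which is why d_softmax_offset occupies the final slot. A toy sketch of the contract (names illustrative, not this PR's kernels):

```python
import torch

class ScaleWithSink(torch.autograd.Function):
    """Toy Function with an optional extra input, mirroring softmax_offset."""

    @staticmethod
    def forward(ctx, x, scale, softmax_offset):
        ctx.scale = scale
        ctx.has_offset = softmax_offset is not None
        out = x * scale
        return (out + softmax_offset) if ctx.has_offset else out

    @staticmethod
    def backward(ctx, grad_out):
        # one slot per forward input: (x, scale, softmax_offset)
        d_offset = grad_out if ctx.has_offset else None
        return grad_out * ctx.scale, None, d_offset

x = torch.randn(4, requires_grad=True)
offset = torch.zeros(4, requires_grad=True)
ScaleWithSink.apply(x, 2.0, offset).sum().backward()
print(offset.grad)  # all ones
```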
......@@ -1390,6 +1397,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -1402,6 +1410,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -1453,6 +1462,7 @@ class FusedAttention(torch.nn.Module):
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
......@@ -1603,6 +1613,8 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta,
quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
)
else:
with self.attention_dropout_ctx():
......@@ -1626,6 +1638,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout,
core_attention_bias_type,
attn_mask_type,
self.softmax_type,
window_size,
None, # rng_gen
fused_attention_backend,
......@@ -1634,6 +1647,7 @@ class FusedAttention(torch.nn.Module):
fp8_meta,
quantizers,
self.deterministic,
softmax_offset,
)
# ...hd -> ...(hd)
......
......@@ -46,6 +46,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
_cu_seqlens_info_with_cp_cache = {}
_seq_chunk_ids_cache_for_reordering_before_attn = {}
_seq_chunk_ids_cache_for_reordering_after_attn = {}
_softmax_offset_chunk_ids_cache = {}
def flash_attn_p2p_communicate(
......@@ -318,6 +319,55 @@ def flash_attn_a2a_communicate(
return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
def flash_attn_a2a_communicate_softmax_offset(
tensor: torch.Tensor,
h_dim: int,
cp_size: int,
cp_group: dist_group_type,
cp_stream: torch.cuda.Stream,
before_attn: bool,
) -> Union[torch.Tensor, None]:
"""Split/AllGather communication for softmax offset."""
if tensor is None:
return None
global _softmax_offset_chunk_ids_cache
device = tensor.device
if (cp_size, device) not in _softmax_offset_chunk_ids_cache:
chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device)
_softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids
else:
chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)]
if before_attn:
# softmax_offset: split the head dim across CP ranks
# [1, h, 1, 1] -> [1, cp, h//cp, 1, 1] -> [1, h//cp, 1, 1]
shape = tensor.shape
tensor = tensor.view(
*shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :]
)
rank = get_distributed_rank(cp_group)
output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank])
output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :])
else:
# d_softmax_offset: all-gather the head shards from all CP ranks
# [1, h//cp, 1, 1] -> [1, h, 1, 1]
inp = tensor.view(-1)
output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device)
with torch.cuda.stream(cp_stream):
torch.distributed.all_gather_into_tensor(
output,
inp,
group=cp_group,
async_op=False,
)
torch.cuda.current_stream().wait_stream(cp_stream)
output = output.view(
*tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :]
)
return output
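The forward split and backward all-gather are inverses because chunk_ids is a plain arange: rank r keeps the r-th contiguous block of heads, and all_gather_into_tensor concatenates the blocks back in rank order. A single-process sanity sketch (no process group, sizes illustrative):

```python
import torch

h, cp_size = 8, 2
offset = torch.arange(h, dtype=torch.float32).view(1, h, 1, 1)

# before attention: rank r keeps heads [r * h//cp, (r+1) * h//cp)
shards = offset.view(1, cp_size, h // cp_size, 1, 1).unbind(dim=1)
assert shards[0].shape == (1, h // cp_size, 1, 1)

# after backward: concatenation in rank order restores the full tensor
regathered = torch.cat(shards, dim=1)
assert torch.equal(regathered, offset)
```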
def _get_cu_seqlens_info_with_cp(
batch_size: int,
max_seqlen: int,
......@@ -1854,7 +1904,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -2014,7 +2064,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -2171,7 +2221,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -2289,7 +2339,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i]
fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q_per_step[cp_size - i - 1],
......@@ -3122,7 +3172,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
if ctx.use_fused_attention:
aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd(
ctx.max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q,
......@@ -3283,6 +3333,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cp_stream,
quantizers,
use_flash_attn_3,
softmax_type,
softmax_offset,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
......@@ -3391,6 +3443,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
q, k, v = flash_attn_a2a_communicate(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
)
if softmax_type != "vanilla":
softmax_offset = flash_attn_a2a_communicate_softmax_offset(
softmax_offset, 1, cp_size, cp_group, cp_stream, True
)
if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16, k_f16, v_f16 = q, k, v
......@@ -3430,6 +3486,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
window_size=window_size,
**fp8_meta_kwargs,
softmax_type=softmax_type,
softmax_offset=softmax_offset,
)
if fp8:
out = out._data
......@@ -3532,6 +3590,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3
ctx.softmax_type = softmax_type
ctx.qkv_dtype = qkv_dtype
ctx.dQKV_quantizer = dQKV_quantizer
......@@ -3695,7 +3754,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
dout_part, fake_dtype=dout_dtype, internal=True
)
dq, dk, dv, _ = fused_attn_bwd(
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q,
......@@ -3719,6 +3778,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
window_size=ctx.window_size,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
softmax_type=ctx.softmax_type,
)
if ctx.fp8:
dq = dq._data
......@@ -3763,6 +3823,17 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
elif ctx.qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
d_bias = None
d_softmax_offset = None
if ctx.use_fused_attention:
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
d_softmax_offset = flash_attn_a2a_communicate_softmax_offset(
d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False
)
if ctx.fp8:
dq = ctx.dQKV_quantizer.create_tensor_from_data(
dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8
......@@ -3793,6 +3864,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
d_bias,
None,
None,
None,
......@@ -3803,6 +3875,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
d_softmax_offset,
)
......@@ -3835,6 +3908,8 @@ def attn_forward_func_with_cp(
quantizers=None,
pad_between_seqs=False,
use_flash_attn_3=False,
softmax_type="vanilla",
softmax_offset=None,
) -> torch.Tensor:
"""
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
......@@ -3911,23 +3986,23 @@ def attn_forward_func_with_cp(
else:
assert isinstance(
cp_group, dist_group_type
), f"Unsupported process group for CP communication type {cp_comm_type}!"
), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!"
assert qkv_format in [
"bshd",
"sbhd",
"thd",
], f"QKV format of {qkv_format} is not supported with context parallelism!"
], f"Context parallelism does not support {qkv_format=}!"
assert (
qkv_format != "sbhd" or use_fused_attention
), "FlashAttention does not support sbhd format!"
), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!"
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
"""Attention bias is only supported with FusedAttention and "causal" """
"""or "no_mask" mask types!"""
"Context parallelism only supports attention bias with FusedAttention backend and"
" non-padding mask types!"
)
assert qkv_format != "thd" or (
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_padded cannot be None with context parallelism + THD format!"
), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!"
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
......@@ -3935,13 +4010,28 @@ def attn_forward_func_with_cp(
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], "The context parallel running configs cannot support sliding window attetnion!"
], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "The context parallel running configs cannot support MLA!"
], "Context parallelism does not support MLA with {cp_comm_type=}!"
if fp8 and fp8_meta is not None:
if fp8_meta["recipe"].fp8_dpa:
assert (
softmax_type == "vanilla"
), "Context parallelism does not support {softmax_type=} with FP8 attention!"
assert (
softmax_type == "vanilla" or use_fused_attention
), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
assert (
softmax_type == "vanilla" or cp_comm_type == "a2a"
), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
assert (
softmax_type == "vanilla" or qkv_format != "thd"
), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
args = [
is_training,
......@@ -3982,7 +4072,17 @@ def attn_forward_func_with_cp(
args += [window_size, cp_group, cp_stream, use_flash_attn_3]
out = AttnFuncWithCPAndKVAllGather.apply(*args)
elif cp_comm_type == "a2a":
args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3]
args += [
window_size,
fp8,
fp8_meta,
cp_group,
cp_stream,
quantizers,
use_flash_attn_3,
softmax_type,
softmax_offset,
]
out = AttnFuncWithCPAndQKVOA2A.apply(*args)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
......
......@@ -11,6 +11,7 @@ import warnings
import logging
import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import get_cudnn_version
......@@ -168,6 +169,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter of shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
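The three formulas can be cross-checked with a short pure-PyTorch reference (helper name and shapes are illustrative; alpha plays the role of the per-head softmax_offset):

```python
import torch

def sink_softmax(scores, softmax_type="vanilla", alpha=None):
    """scores: [b, h, s_q, s_kv]; alpha: [h], used only for 'learnable'."""
    if softmax_type == "vanilla":
        return torch.softmax(scores, dim=-1)
    if softmax_type == "off-by-one":
        sink = scores.new_zeros(*scores.shape[:-1], 1)  # zero sink
    else:  # "learnable"
        sink = alpha.view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)
    # softmax over [S, sink], then drop the sink column
    return torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[..., :-1]

h = 4
alpha = torch.randn(h)
s = torch.randn(2, h, 3, 5)
p = sink_softmax(s, "learnable", alpha)
ref = s.exp() / (alpha.exp().view(1, h, 1, 1) + s.exp().sum(-1, keepdim=True))
assert torch.allclose(p, ref, atol=1e-6)
```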
Parallelism parameters
----------------------
......@@ -223,6 +235,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -307,6 +320,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
......@@ -328,6 +355,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
)
self.unfused_attention = UnfusedDotProductAttention(
......@@ -335,6 +363,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
)
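Putting the constructor pieces together, a hedged usage sketch (sizes illustrative; assumes a CUDA device, since the offset tensors above are created on "cuda"):

```python
from transformer_engine.pytorch import DotProductAttention

# "learnable" registers a trainable per-head softmax_offset parameter;
# "off-by-one" keeps a fixed zeros tensor; "vanilla" leaves it as None.
attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    softmax_type="learnable",
)
```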
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -634,6 +663,7 @@ class DotProductAttention(TransformerEngineBaseModule):
query_layer,
num_gemms=3,
allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer:
# checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing():
......@@ -922,6 +952,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
......@@ -957,11 +988,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
)
global _attention_backends
if is_in_onnx_export_mode():
......@@ -1022,6 +1055,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
......@@ -1071,7 +1110,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
# checkpoint_core_attention=False
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
......@@ -1101,6 +1139,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.fused_attention(
query_layer,
......@@ -1129,6 +1168,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
......@@ -1157,6 +1197,7 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return self.unfused_attention(
_alibi_cache,
......@@ -1173,5 +1214,6 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
)
return None
......@@ -24,6 +24,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
SoftmaxType,
FusedAttnBackend,
META_QKV,
META_DQKV,
......@@ -206,6 +207,8 @@ class AttentionParams:
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
cp_comm_type: str, default = `p2p`
The communication type of context parallelism.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
......@@ -216,6 +219,8 @@ class AttentionParams:
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = `vanilla`
The type of softmax operation. See DotProductAttention for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -237,11 +242,13 @@ class AttentionParams:
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other):
"""
......@@ -308,11 +315,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -565,6 +574,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
......@@ -806,6 +860,7 @@ def get_attention_backend(
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout,
num_heads,
num_gqa_groups,
......
......@@ -135,6 +135,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to obtain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter of shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters
----------------------
......@@ -245,6 +256,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -262,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias
self.cp_size = 1
self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
......@@ -416,6 +429,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
)
# Linear
......
......@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
......@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
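Like AttnBiasType and AttnMaskType above, this maps the Python-side option string to the C-side enum before the tex calls. A one-line sketch, relying on the names defined in this module:

```python
# hedged sketch, using the SoftmaxType dict and the enum imported above
assert SoftmaxType["off-by-one"] == NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
```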
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
......@@ -131,8 +138,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -197,6 +206,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -205,6 +216,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
Returns
----------
......@@ -286,6 +300,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
......@@ -300,6 +315,7 @@ def fused_attn_fwd(
s_quantizer,
o_quantizer,
attn_bias,
softmax_offset,
rng_gen,
rng_elts_per_thread,
)
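A hedged sketch of shaping the offset the way this wrapper documents it ([1, h_q, 1, 1], float32), mirroring the reshape done in DotProductAttention.forward (h_q illustrative):

```python
import torch

h_q = 16
alpha = torch.zeros(h_q)  # zero sink ("off-by-one"); "learnable" uses a Parameter
softmax_offset = alpha.reshape(1, -1, 1, 1).to(torch.float32)
assert softmax_offset.shape == (1, h_q, 1, 1)
```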
......@@ -333,6 +349,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -398,6 +415,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -417,6 +436,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset, of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
d = q.size(-1)
......@@ -454,6 +476,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
deterministic,
cu_seqlens_q,
......