Unverified Commit 87cb26c6 authored by Charlene Yang, committed by GitHub

[PyTorch] Add max_logit support for MuonClip (#2195)



* add max_score for fused/unfused F16 non-CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* calculate max per head instead of max over all heads
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn max_score shape
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert FE to github
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.15.0-rc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reduce ew kernels; fix causal masks; add more tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix to tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove logic for flash-attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add CP support for p2p/a2a/all_gather
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor improvements of implementation/tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: add thd support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add thd to UnfusedDPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* more fixes for lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update to FE 1.15
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unneeded changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable unfused for thd + pad_between_seqs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable thd for unfused until bug is fixed
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix all_gather
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix all gather
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename max_score to max_logit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix all_gather
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix all_gather
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable fused attn + thd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 060811c9
Subproject commit 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4
Subproject commit 0b1577c8c83401237d601d0d0db5210506705396
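
Before the diff: a minimal usage sketch of the new option (not part of this commit), inferred from the test changes below. The (out, max_logit) tuple return and the per-head shape of max_logit are assumptions based on this PR rather than a documented API.

import torch
from transformer_engine.pytorch import DotProductAttention

num_heads, head_dim = 16, 64
attn = DotProductAttention(
    num_heads,
    head_dim,
    attn_mask_type="causal",
    return_max_logit=True,  # new flag added by this PR
).cuda()

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim]
q, k, v = [
    torch.randn(512, 2, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
    for _ in range(3)
]

# With return_max_logit=True the forward pass also returns the per-head maximum
# attention logit, which MuonClip-style training can use to rescale QK projections.
out, max_logit = attn(q, k, v)
print(max_logit.shape)  # expected: torch.Size([num_heads])
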
......@@ -248,6 +248,7 @@ def run_dpa_with_cp(
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
......@@ -308,6 +309,7 @@ def run_dpa_with_cp(
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
max_logit = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out = core_attn(
......@@ -322,6 +324,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_logit:
out, max_logit = out
if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
......@@ -400,6 +404,7 @@ def run_dpa_with_cp(
fp8_context = nullcontext()
# run attention
max_logit_ = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out_ = core_attn(
......@@ -414,6 +419,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_logit:
out_, max_logit_ = out_
if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
......@@ -495,15 +502,15 @@ def run_dpa_with_cp(
)
atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
......
......@@ -131,6 +131,11 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
)
# Get backends
is_training = True
......@@ -172,7 +177,7 @@ def test_dot_product_attention(
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"UnfusedDotProductAttention",
......@@ -186,7 +191,7 @@ def test_dot_product_attention(
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
......@@ -198,7 +203,7 @@ def test_dot_product_attention(
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
......@@ -209,7 +214,7 @@ def test_dot_product_attention(
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
......@@ -222,7 +227,7 @@ def test_dot_product_attention(
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
......@@ -243,6 +248,8 @@ def test_dot_product_attention(
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_logit:
torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
......@@ -266,6 +273,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_max_logit = {
# test: ModelConfig(b, sq, hq, dqk)
"max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"max_logit_4": ModelConfig(
8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
),
"max_logit_5": ModelConfig(
8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
),
"max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with checkpointing"""
config = model_configs[model]
config.return_max_logit = True
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
......@@ -962,6 +996,8 @@ def _run_dot_product_attention(
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
......@@ -1071,6 +1107,7 @@ def _run_dot_product_attention(
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
......@@ -1108,16 +1145,21 @@ def _run_dot_product_attention(
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
max_logit = None
if config.return_max_logit:
out, max_logit = out
if is_training:
out.backward(d_out)
d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
......@@ -1146,14 +1188,18 @@ def _run_dot_product_attention(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
return (
out_orig,
max_logit,
(q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
)
else:
return out_orig, (None, None, None, d_softmax_offset)
return out_orig, max_logit, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_logit, (None, None, None, d_softmax_offset)
model_configs_te_layer = {
......
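
The per-head max_logit that these tests compare is simply a masked maximum over the attention scores; a standalone sketch of that reduction, mirroring the unfused reference path later in this diff (illustrative shapes only, not the module's internal code):

import torch

b, np_, sq, sk, dk = 2, 16, 128, 128, 64
scale = 1.0 / (dk ** 0.5)

q = torch.randn(b, np_, sq, dk)
k = torch.randn(b, np_, sk, dk)
scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [b, np, sq, sk]

# Mask first so masked-out positions cannot win the max, then reduce over batch
# and both sequence dimensions, keeping one value per attention head.
causal_mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
masked_scores = scores.masked_fill(causal_mask, float("-inf"))
max_logit = torch.amax(masked_scores, dim=(0, 2, 3))  # [np]
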
......@@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
......@@ -183,7 +183,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......
......@@ -205,6 +205,7 @@ class ModelConfig:
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
return_max_logit=False,
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
......@@ -233,6 +234,7 @@ class ModelConfig:
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.return_max_logit = return_max_logit
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
......@@ -318,6 +320,7 @@ def get_available_attention_backends(
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
)
(
use_flash_attention,
......
......@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool return_max_logit) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
(cudnn_runtime_version != 91000) && !return_max_logit) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
......@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
!requires_64bit_ragged_offset &&
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
(softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) {
flag_m512 = true;
}
if (
......@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
size_t max_seqlen, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
......@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
input_rng_state, wkspace, stream, handle);
b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
......@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset,
output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -832,17 +833,15 @@ void nvte_fused_attn_bwd_kvpacked(
}
}
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
......@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
return_max_logit);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
wkspace, stream, handle);
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......
......@@ -20,12 +20,13 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
......@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......
......@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
qkv_tensor_type,
o_tensor_type,
cudnn_frontend::DataType_t::NOT_SET,
cudnn_frontend::DataType_t::NOT_SET};
cudnn_frontend::DataType_t::NOT_SET,
false};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
qkv_tensor_type,
o_tensor_type,
do_tensor_type,
dqkv_tensor_type};
dqkv_tensor_type,
false};
namespace fe = cudnn_frontend;
using graph_and_tensors =
......
......@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t o_tensor_type;
cudnn_frontend::DataType_t do_tensor_type;
cudnn_frontend::DataType_t dqkv_tensor_type;
bool generate_max_sum_exp;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
o_tensor_type, do_tensor_type, dqkv_tensor_type) <
o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type);
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
}
};
......
......@@ -206,13 +206,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool return_max_logit);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
......@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
......@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
......@@ -531,17 +538,15 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
......
......@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
return backend;
}
......@@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
......@@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
......@@ -276,7 +278,8 @@ static void FusedAttnForwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -294,7 +297,7 @@ static void FusedAttnForwardImpl(
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training,
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
......@@ -308,8 +311,8 @@ static void FusedAttnForwardImpl(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
......@@ -323,7 +326,7 @@ static void FusedAttnForwardImpl(
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else {
......@@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
......
......@@ -58,6 +58,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
combine_and_quantize,
combine_and_dequantize,
print_quantizers,
ConvertTHDtoBSHD,
ConvertBSHDtoTHD,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
......@@ -201,6 +203,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None:
super().__init__()
......@@ -209,6 +212,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
self.return_max_logit = return_max_logit
def mask_func(x, y):
return (
......@@ -217,6 +221,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
else attention_mask_func(x, y)
)
self.mask_func = mask_func
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate
......@@ -238,6 +243,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
......@@ -261,6 +268,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
if inference_params is not None and inference_params.is_paged:
key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number)
# convert to sbhd
# training: bshd, thd
# inference: bshd, sbhd_2bshd, thd_2bshd
if qkv_format == "bshd":
# convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [
......@@ -269,9 +279,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
if qkv_format == "sbhd_2bshd":
key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]]
total_tokens, batch_size = None, None
if qkv_format == "thd_2bshd":
total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0]
batch_size = key_layer.shape[0]
query_layer = tex.convert_thd_to_bshd(
query_layer,
cu_seqlens_q,
......@@ -281,6 +290,26 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
]
if qkv_format == "thd":
assert cu_seqlens_q is not None and cu_seqlens_kv is not None
assert max_seqlen_q is not None and max_seqlen_kv is not None
query_layer = ConvertTHDtoBSHD.apply(
query_layer,
cu_seqlens_q,
max_seqlen_q,
)
key_layer, value_layer = [
ConvertTHDtoBSHD.apply(
x,
cu_seqlens_kv,
max_seqlen_kv,
)
for x in [key_layer, value_layer]
]
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous() for x in [query_layer, key_layer, value_layer]
]
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1],
query_layer.shape[0],
......@@ -426,6 +455,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
matmul_result, None, None, dP_quantizer, "dP_quantizer", None
)
# max attention score
max_logit = None
if self.return_max_logit:
# matmul_result [b, np, sq, sk], max_logit [np]
max_logit = matmul_result
if attn_mask_type != "no_mask":
max_logit = self.mask_func(matmul_result, attention_mask)
max_logit = torch.amax(max_logit, dim=(0, 2, 3))
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
......@@ -506,14 +544,13 @@ class UnfusedDotProductAttention(torch.nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [b, sq, np, hn] --> [tq, np, hn]
context_layer = tex.convert_bshd_to_thd(
context_layer = ConvertBSHDtoTHD.apply(
context_layer,
cu_seqlens_q,
total_tokens,
)
# [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1)
context_layer = context_layer.view(context_layer.shape[0], -1)
if fp8:
# quantize and dequantize O to emulate FP8
......@@ -529,6 +566,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
if fp8_output:
context_layer = O_quantizer(context_layer)
if self.return_max_logit:
return context_layer, max_logit
return context_layer
......@@ -1067,6 +1107,7 @@ class FusedAttnFunc(torch.autograd.Function):
softmax_offset,
fp8_output,
layer_number,
return_max_logit,
):
# pylint: disable=missing-function-docstring
......@@ -1102,6 +1143,7 @@ class FusedAttnFunc(torch.autograd.Function):
# FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype
max_logit = None
if fp8:
fused_attention_backend = FusedAttnBackend["FP8"]
......@@ -1129,7 +1171,7 @@ class FusedAttnFunc(torch.autograd.Function):
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
out_, aux_ctx_tensors, *_ = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -1205,7 +1247,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkvo_tensors = (q, k, v, out)
else:
# q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
out_, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -1233,6 +1275,7 @@ class FusedAttnFunc(torch.autograd.Function):
window_size,
rng_gen,
softmax_offset,
return_max_logit,
)
out = out_
out_ret = out_
......@@ -1327,10 +1370,12 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.use_FAv2_bwd = use_FAv2_bwd
ctx.deterministic = deterministic
if return_max_logit:
return out_ret, *max_logit
return out_ret
@staticmethod
def backward(ctx, d_out):
def backward(ctx, d_out, *_args):
# pylint: disable=missing-function-docstring
# d_out is expected to be in FP8 if is_output_fp8=True,
......@@ -1574,6 +1619,7 @@ class FusedAttnFunc(torch.autograd.Function):
d_softmax_offset,
None,
None,
None,
)
......@@ -1614,6 +1660,7 @@ class FusedAttention(torch.nn.Module):
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None:
super().__init__()
......@@ -1627,6 +1674,7 @@ class FusedAttention(torch.nn.Module):
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
self.return_max_logit = return_max_logit
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -1846,6 +1894,7 @@ class FusedAttention(torch.nn.Module):
softmax_offset=softmax_offset,
fp8_output=fp8_output,
layer_number=self.layer_number,
return_max_logit=self.return_max_logit,
)
else:
with self.attention_dropout_ctx():
......@@ -1881,7 +1930,11 @@ class FusedAttention(torch.nn.Module):
softmax_offset,
fp8_output,
self.layer_number,
self.return_max_logit,
)
if self.return_max_logit:
# ...hd -> ...(hd)
return output[0].view(*output[0].shape[:-2], -1), output[1]
# ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1)
......@@ -617,6 +617,7 @@ def cp_p2p_fwd_fused_attn(
rank,
step,
cp_size,
return_max_logit,
q_part,
k_part,
v_part,
......@@ -693,7 +694,7 @@ def cp_p2p_fwd_fused_attn(
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step
fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step
out_per_step, aux_ctx_tensors = fused_attn_fwd(
out_per_step, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training,
max_seqlen_q_,
max_seqlen_kv_,
......@@ -713,6 +714,7 @@ def cp_p2p_fwd_fused_attn(
cu_seqlens_q_padded=cu_seqlens_q_padded_,
cu_seqlens_kv_padded=cu_seqlens_kv_padded_,
**fp8_meta_kwargs,
return_max_logit=return_max_logit,
)
if fp8:
......@@ -721,7 +723,9 @@ def cp_p2p_fwd_fused_attn(
softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors
attn_bias = rest[0] if len(rest) > 0 else None
return out_per_step, softmax_lse_per_step, rng_states, attn_bias
if return_max_logit:
return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit
return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None
def cp_p2p_fwd_flash_attn(
......@@ -1086,6 +1090,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
fp8,
fp8_meta,
cp_group,
......@@ -1156,6 +1161,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
amax_per_step = None
S_quantizer_per_step = [None for _ in range(cp_size)]
O_quantizer_per_step = [None for _ in range(cp_size)]
max_logit_per_step = [None for _ in range(cp_size)]
max_logit = None
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
......@@ -1244,6 +1251,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_f16 = q
if use_fused_attention:
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if return_max_logit:
max_logit_per_step = [
torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size)
]
# split qkv to two halves and prepare for load balancing
assert qkv_format == "thd" or (
......@@ -1418,6 +1429,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
rank,
i,
cp_size,
return_max_logit,
]
else:
flash_attn_inputs = [
......@@ -1462,6 +1474,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(
*fused_attn_inputs, *prepare_outputs, section
)
......@@ -1488,6 +1501,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(
*fused_attn_inputs, *prepare_outputs, section
)
......@@ -1514,6 +1528,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(
*fused_attn_inputs, *prepare_outputs, section
)
......@@ -1541,6 +1556,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section)
else:
out_per_step[i], softmax_lse_per_step[i], rng_states[i] = (
......@@ -1600,11 +1616,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
softmax_lse_per_step[i - 1],
)
if return_max_logit:
if i == 1:
max_logit = torch.clone(max_logit_per_step[0])
else:
max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])
if i < cp_size:
flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
if return_max_logit:
torch.distributed.all_reduce(
max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group
)
second_half_lse_seqlen = None
if causal and rank < (cp_size - 1):
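
The context-parallel forward paths above combine the per-step statistics with a plain max reduction, then take the maximum across ranks; a minimal sketch of that pattern, assuming an already-initialized process group (function and variable names here are illustrative, not TE internals):

import torch
import torch.distributed as dist

def reduce_max_logit(max_logit_per_step, cp_group):
    """Local max over the CP steps, then a MAX all-reduce across CP ranks."""
    max_logit = max_logit_per_step[0].clone()
    for step_max in max_logit_per_step[1:]:
        max_logit = torch.maximum(max_logit, step_max)
    # Each rank only attends over its own KV chunks, so the all-reduce yields
    # the global per-head maximum logit across the whole sequence.
    dist.all_reduce(max_logit, op=dist.ReduceOp.MAX, group=cp_group)
    return max_logit
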
......@@ -1682,6 +1707,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd":
# [s*b, h, d] -> [s, b, h, d]
out = out.view(-1, ctx.batch_size, *out.shape[-2:])
if return_max_logit:
max_logit = flash_attn_a2a_communicate_softmax_offset(
max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False
)
elif not use_fused_attention:
out = out.view(-1, *out.shape[-2:])
......@@ -1811,10 +1840,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}")
if return_max_logit:
return out_ret, max_logit
return out_ret
@staticmethod
def backward(ctx, dout):
def backward(ctx, dout, *_args):
# pylint: disable=missing-function-docstring
# add NVTX range
......@@ -2522,6 +2553,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -2577,6 +2609,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
window_size,
cp_group,
cp_stream,
......@@ -2682,6 +2715,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
softmax_lse_per_step = [None, None]
rng_states = [None, None]
out = torch.empty_like(q)
max_logit_per_step = [None, None]
max_logit = None
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
......@@ -2712,7 +2747,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d]
k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
(
out_per_step[i],
[softmax_lse_per_step[i], rng_states[i]],
*max_logit_,
) = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv_,
......@@ -2732,7 +2771,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
window_size=window_size_per_step[i],
return_max_logit=return_max_logit,
)
if return_max_logit:
max_logit_per_step[i] = max_logit_[0]
else:
fa_forward_args_thd = get_fa_args(
True,
......@@ -2767,14 +2809,22 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
if return_max_logit and i == 0:
max_logit = torch.clone(max_logit_per_step[0])
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
if qkv_format == "bshd":
out[:, i - 1].copy_(out_per_step[i - 1])
elif qkv_format == "sbhd":
out[i - 1].copy_(out_per_step[i - 1])
if return_max_logit:
max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])
torch.cuda.current_stream().wait_stream(cp_stream)
if return_max_logit:
torch.distributed.all_reduce(
max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group
)
if use_fused_attention:
if qkv_format == "bshd":
......@@ -2811,10 +2861,12 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
ctx.use_fused_attention = use_fused_attention
ctx.use_flash_attn_3 = use_flash_attn_3
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
if return_max_logit:
return out, max_logit
return out
@staticmethod
def backward(ctx, dout):
def backward(ctx, dout, *_args):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
......@@ -3035,6 +3087,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -3065,6 +3118,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
window_size,
fp8,
fp8_meta,
......@@ -3158,6 +3212,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fp8_recipe = fp8_meta["local_recipes"][0]
fwd_nominal_dtype = q.dtype
fused_attn_backend = None
max_logit = None
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers)
......@@ -3203,7 +3258,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype)
for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part])
]
out_, aux_ctx_tensors = fused_attn_fwd(
out_, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -3226,6 +3281,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
**fp8_meta_kwargs,
softmax_type=softmax_type,
softmax_offset=softmax_offset,
return_max_logit=return_max_logit,
)
if isinstance(out_, Float8Tensor):
out_fp8 = out_
......@@ -3276,6 +3332,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out_ = flash_attn_a2a_communicate(
out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
)
if return_max_logit:
max_logit = flash_attn_a2a_communicate_softmax_offset(
*max_logit, 0, cp_size, cp_group, cp_stream, False
)
if use_fused_attention:
if qkv_format == "bshd":
......@@ -3362,10 +3422,12 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.S_quantizer = S_quantizer.copy()
ctx.S_quantizer.scale = S_quantizer.scale.clone()
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
if return_max_logit:
return out_ret, max_logit
return out_ret
@staticmethod
def backward(ctx, dout):
def backward(ctx, dout, *_args):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
......@@ -3599,6 +3661,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
None,
d_softmax_offset,
None,
)
......@@ -3637,6 +3700,7 @@ def attn_forward_func_with_cp(
softmax_offset=None,
fp8_output=False,
layer_number=1,
return_max_logit=False,
) -> torch.Tensor:
"""
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
......@@ -3784,6 +3848,7 @@ def attn_forward_func_with_cp(
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
]
if cp_comm_type in ["p2p", "a2a+p2p"]:
......
......@@ -255,6 +255,12 @@ class DotProductAttention(TransformerEngineBaseModule):
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
return_max_logit: Optional[bool], default = `False`
If `True`, also return the maximum attention score, which can be used in a Muon optimizer
to rescale the Q and K projection weights (see `Muon is Scalable for LLM Training
<https://arxiv.org/pdf/2502.16982>`_).
max_logit = max(S), taken over the batch and both sequence dimensions, where
S = mask(Q*K^T*softmax_scale + bias) has shape [b, h, s_q, s_kv]; max_logit has shape [h].
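For reference, the quantity described above can be sketched in plain PyTorch. This is illustrative only, not the fused kernel; a [b, h, s, d] layout and a simple causal mask are assumed:

import torch

def reference_max_logit(q, k, softmax_scale, bias=None, causal=True):
    # q: [b, h, s_q, d], k: [b, h, s_kv, d] (illustrative layout)
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    if bias is not None:
        scores = scores + bias
    if causal:
        s_q, s_kv = scores.shape[-2:]
        mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    # Per-head maximum over batch and both sequence dimensions -> shape [h]
    return scores.amax(dim=(0, 2, 3))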
Parallelism parameters
----------------------
......@@ -311,6 +317,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None:
super().__init__()
......@@ -394,6 +401,7 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.return_max_logit = return_max_logit
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
......@@ -431,6 +439,7 @@ class DotProductAttention(TransformerEngineBaseModule):
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
)
self.unfused_attention = UnfusedDotProductAttention(
......@@ -439,6 +448,7 @@ class DotProductAttention(TransformerEngineBaseModule):
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
)
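Hedged usage sketch: with return_max_logit=True, the module's forward pass is assumed to return the attention output together with the per-head max_logit tensor (shapes follow the documentation above; default sbhd layout):

import torch
from transformer_engine.pytorch import DotProductAttention

# return_max_logit=True: forward is assumed to yield (output, max_logit)
attn = DotProductAttention(num_attention_heads=16, kv_channels=64, return_max_logit=True)
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")  # [s, b, h, d]
k, v = torch.randn_like(q), torch.randn_like(q)
out, max_logit = attn(q, k, v)  # max_logit: [h], feeds the MuonClip rescaling of W_q/W_k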
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -1303,6 +1313,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
)
global _attention_backends
if is_in_onnx_export_mode():
......@@ -1502,6 +1513,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
......@@ -1523,6 +1536,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
......
......@@ -229,6 +229,8 @@ class AttentionParams:
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
return_max_logit: bool, default = `False`
Whether to also output max_logit. See DotProductAttention for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -257,6 +259,7 @@ class AttentionParams:
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
return_max_logit: bool = False
def __eq__(self, other):
"""
......@@ -330,6 +333,7 @@ def get_attention_backend(
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
return_max_logit = attention_params.return_max_logit
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -477,6 +481,20 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
use_fused_attention = False
# Filter: Return max_logit
if return_max_logit:
if use_flash_attention:
use_flash_attention = False
logger.debug("Disabling FlashAttention for max_logit")
if use_fused_attention and qkv_format == "thd":
use_fused_attention = False
logger.debug("Disabling FusedAttention for max_logit with qkv_format = thd")
if fp8 and fp8_meta["recipe"].fp8_dpa:
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
logger.debug("Disabling all backends for max_logit with FP8 attention")
# Filter: KV cache
# backend | precision | KV cache | architecture | qkv_format | page_size
# ---------------------------------------------------------------------------------------
......@@ -913,6 +931,7 @@ def get_attention_backend(
head_dim_v,
window_size[0],
window_size[1],
return_max_logit,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
......@@ -1649,6 +1668,78 @@ class UnpackTensor(torch.autograd.Function):
return None, None, _pack_tensor(indices, grad_output)
class ConvertTHDtoBSHD(torch.autograd.Function):
"""
Convert a tensor from qkv_format = thd to qkv_format = bshd.
"""
@staticmethod
def forward(ctx, thd_tensor, cu_seqlens, max_seqlen):
# pylint: disable=missing-function-docstring
batch_size = cu_seqlens.shape[0] - 1
if not thd_tensor.is_contiguous():
thd_tensor = thd_tensor.contiguous()
bshd_tensor = tex.convert_thd_to_bshd(
thd_tensor,
cu_seqlens,
batch_size,
max_seqlen,
)
ctx.save_for_backward(cu_seqlens)
ctx.num_tokens = thd_tensor.shape[0]
return bshd_tensor
@staticmethod
def backward(ctx, bshd_tensor):
# pylint: disable=missing-function-docstring
(cu_seqlens,) = ctx.saved_tensors
if not bshd_tensor.is_contiguous():
bshd_tensor = bshd_tensor.contiguous()
thd_tensor = tex.convert_bshd_to_thd(
bshd_tensor,
cu_seqlens,
ctx.num_tokens,
)
return thd_tensor, None, None
class ConvertBSHDtoTHD(torch.autograd.Function):
"""
Convert a tensor from qkv_format = bshd to qkv_format = thd.
"""
@staticmethod
def forward(ctx, bshd_tensor, cu_seqlens):
# pylint: disable=missing-function-docstring
num_tokens = cu_seqlens[-1]
max_seqlen = bshd_tensor.shape[1]
if not bshd_tensor.is_contiguous():
bshd_tensor = bshd_tensor.contiguous()
thd_tensor = tex.convert_bshd_to_thd(
bshd_tensor,
cu_seqlens,
num_tokens,
)
ctx.save_for_backward(cu_seqlens)
ctx.max_seqlen = max_seqlen
return thd_tensor
@staticmethod
def backward(ctx, thd_tensor):
# pylint: disable=missing-function-docstring
(cu_seqlens,) = ctx.saved_tensors
batch_size = cu_seqlens.shape[0] - 1
if not thd_tensor.is_contiguous():
thd_tensor = thd_tensor.contiguous()
bshd_tensor = tex.convert_thd_to_bshd(
thd_tensor,
cu_seqlens,
batch_size,
ctx.max_seqlen,
)
return bshd_tensor, None
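The two autograd functions above delegate to the tex.convert_thd_to_bshd / tex.convert_bshd_to_thd kernels. As a hedged illustration of the layout change they implement, a pure-PyTorch equivalent of the thd -> bshd packing could look like the following (sketch only, not the extension kernel):

import torch

def thd_to_bshd_reference(thd, cu_seqlens, max_seqlen):
    # thd: [total_tokens, h, d]; cu_seqlens: [b + 1] cumulative sequence lengths
    batch_size = cu_seqlens.shape[0] - 1
    h, d = thd.shape[1:]
    bshd = thd.new_zeros(batch_size, max_seqlen, h, d)
    for b in range(batch_size):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        # Scatter each variable-length sequence into its padded row.
        bshd[b, : end - start] = thd[start:end]
    return bshd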
def get_qkv_format(
qkv_layout: str = "bshd_bshd_bshd",
inference_params: InferenceParams = None,
......
......@@ -139,6 +139,7 @@ def fused_attn_fwd(
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
return_max_logit: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -216,6 +217,8 @@ def fused_attn_fwd(
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
return_max_logit: bool, default = False
whether to return the maximum attention score
Returns
----------
......@@ -246,6 +249,7 @@ def fused_attn_fwd(
rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator;
[seed, offset], dtype uint64
max_logit: torch.Tensor, only returned when return_max_logit = True; shape [h], same data type as O
"""
if attn_scale is None:
......@@ -315,8 +319,22 @@ def fused_attn_fwd(
softmax_offset,
rng_gen,
rng_elts_per_thread,
return_max_logit,
)
if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
stats = output_tensors[1] + torch.log(output_tensors[2])
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:]
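The reconstruction above relies on the identity logsumexp(x) = max(x) + log(sum(exp(x - max(x)))): when return_max_logit is requested, the backend returns the row-wise Max and Sum_Exp stats separately, the usual softmax stats tensor is rebuilt from them, and the per-head maximum is reduced out of Max. A small self-contained check of that identity, using stand-in bshd-style shapes:

import torch

scores = torch.randn(2, 4, 16, 16, dtype=torch.float64)        # stand-in for [b, h, sq, skv]
Max = scores.amax(dim=-1, keepdim=True)                         # [b, h, sq, 1]
Sum_Exp = torch.exp(scores - Max).sum(dim=-1, keepdim=True)     # [b, h, sq, 1]
stats = Max + torch.log(Sum_Exp)                                # reconstructed softmax stats
assert torch.allclose(stats, torch.logsumexp(scores, dim=-1, keepdim=True))
max_logit = Max.amax(dim=(0, 2, 3))                             # per-head maximum, shape [h]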
......
......@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool return_max_logit);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
......@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
size_t rng_elts_per_thread, bool return_max_logit);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......
......@@ -45,11 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool return_max_logit) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
return_max_logit);
return fused_attention_backend;
}
......@@ -106,7 +107,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
size_t rng_elts_per_thread, bool return_max_logit) {
auto none = py::none();
// create QKV tensor wrappers
......@@ -228,8 +229,9 @@ std::vector<py::object> fused_attn_fwd(
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace and auxiliary output tensors
......@@ -249,7 +251,9 @@ std::vector<py::object> fused_attn_fwd(
};
// allocate memory for nvte_aux_tensor_pack.tensors
// f16_max512 : S [b, h, sq, skv]
// f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// f16_arbitrary:
// return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
size_t i = 0;
at::Tensor output_tensor;
......@@ -258,8 +262,8 @@ std::vector<py::object> fused_attn_fwd(
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
// fp8 has an additional softmax stats tensor, ZInv
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
......@@ -285,8 +289,9 @@ std::vector<py::object> fused_attn_fwd(
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers, but not allocated memory
......