Unverified commit 83a4c219, authored by cyanguwa, committed by GitHub

[C/PyTorch] Add FP8 DPA and MHA (#768)



* WIP: fp8 v1 fprop integration
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fprop working for h1; w/ debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup; bprop running but has mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gitlab frontend as submodule
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix after merge; add bias_b/h to caching descriptor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* distinguish fwd/bwd tensor types for bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for F16 cases; include added dqkv_type and d_scale_dp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust out shape for bwd in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add casting from/to FP8 to DPA module
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: bshd_bshd_bshd layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: support all sbhd/bshd layouts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add qkvpacked and kvpacked support in both FusedAttnFunc and C levels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove qkvpacked/kvpacked calls in DPA module (used for testing)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA/GQA in FP8 v1 API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 705d8e3, with API change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test causal mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict mha_fill for THD format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn with CP and comment out is_alibi code
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up FE0.9 vs FE1.0 FP8 implementations, and related unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change NVTE_FP8_DPA_BWD default to 1, and fix its use in qkvpacked/kvpacked APIs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and self.tp_size/group in FusedAttention()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 6902c94
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 MHA support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to FE v1.3.0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for FP8 MHA with different configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* emit stats regardless of is_training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix linear when input is not Float8Tensor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix d_out type when f16 bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix user buffer for layernorm_linear/linear and revert two FP8 casts in MHA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for fp8_dpa/mha in recipe
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix backend selection to avoid FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use RMSE for FP8 unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace two more transposes with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 initialization to FusedAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rm docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert "add FP8 initialization to FusedAttention"

This reverts commit 15fffd825d6f23f31ea709b16ba01dfd61efabf8.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change order of ctxs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back docs and mark as beta
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for tests and docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f69e45be
Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348
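The changes in this PR are driven by two new `DelayedScaling` recipe flags (`fp8_dpa`, `fp8_mha`) and the `NVTE_FP8_DPA_BWD` environment variable. A minimal usage sketch of that API, assuming a build of Transformer Engine containing this PR and a Hopper GPU (not runnable elsewhere; module sizes are illustrative):

```python
import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Force the F16 backward path for FP8 DPA; after this PR the default is 1 (FP8 bwd).
os.environ["NVTE_FP8_DPA_BWD"] = "0"

# fp8_dpa/fp8_mha are the recipe flags added in this PR.
fp8_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.HYBRID,
    amax_history_len=1,
    amax_compute_algo="most_recent",
    fp8_dpa=True,   # run dot-product attention in FP8
    fp8_mha=True,   # keep MHA input/output in FP8 as well
)

mha = te.MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    fuse_qkv_params=True,
    params_dtype=torch.float16,
    qkv_format="bshd",
).cuda()

x = torch.randn(2, 512, 1024, dtype=torch.float16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = mha(x, attn_mask_type="causal")
```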
@@ -6,7 +6,7 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
pip install pytest==7.2 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import math
import functools
from importlib.metadata import version
import os
@@ -12,9 +13,10 @@ import pytest
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import (
DotProductAttention,
MultiheadAttention,
RotaryPositionEmbedding,
)
from transformer_engine.pytorch.constants import TE_DType
@@ -939,52 +941,415 @@ def _run_transformer_layer(
return out, inp.grad
model_configs_fp8 = {
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
}
param_types_fp8 = [torch.float16]
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd']
qkv_format_fp8_vs_f16 = ['bshd', 'sbhd']
def _rmse(a, b):
return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum())
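The new FP8-vs-F16 tests gate on a relative criterion rather than elementwise tolerances alone: the RMSE between the two outputs must stay below `rmse_tol` times the combined dynamic range of both tensors. A pure-Python restatement of that check, using plain lists and illustrative values instead of CUDA tensors:

```python
import math

def rmse(a, b):
    # Root-mean-square error between two equal-length sequences.
    assert len(a) == len(b) and len(a) > 0
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))

def within_rmse_tol(a, b, rmse_tol=0.1):
    # Mirrors the test's criterion: RMSE < rmse_tol * (combined dynamic range).
    value_range = max(max(a), max(b)) - min(min(a), min(b))
    return rmse(a, b) < rmse_tol * value_range

fp8_out = [0.10, 0.21, -0.29, 0.41]
f16_out = [0.11, 0.20, -0.30, 0.40]
print(within_rmse_tol(fp8_out, f16_out))  # True: RMSE 0.01 vs 0.1 * 0.71
```

Scaling the tolerance by the observed value range keeps the criterion meaningful across configs whose outputs have very different magnitudes.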
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, model):
"""Test FP8 dot product attention
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fp8_vs_f16[model]
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
if _NVTE_DEBUG:
print()
print("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm)
if _NVTE_DEBUG:
print()
print("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm)
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
fused_attn_fwd_f16.min().item())
if _NVTE_DEBUG:
print()
print('========== {:^25s} =========='.format('forward output'))
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
print('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
print(e)
print()
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
fused_attn_bwd_f16[i].min().item())
if _NVTE_DEBUG:
print()
print('========== {:^25s} =========='.format(param_names[i]))
print('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
print('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
print('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
print(e)
print()
assert(bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
FusedAttention uses fused_attn_fwd/bwd_qkvpacked from cpp_extensions,
and UnfusedDotProductAttention uses plain PyTorch operations in FP16
and converts inputs/outputs from/to FP8.
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_mha,
fp8_mha=fp8_mha,
)
"""
with fp8_model_init(enabled=fp8_mha):
mha = (MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
bias=True,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
input_layernorm=input_layernorm,
fuse_qkv_params=True,
attention_type="self",
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
)
config = model_configs_fp8[model]
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
# Skip if not supported
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype)
if not fused_attn_supported:
pytest.skip("FusedAttention does not support this model config")
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
layout = '_'.join(qkv_format)
layout = layout.replace('s', 'sq')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
out = mha(hidden_states,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
)
out.backward(out_grad)
# Run dot-product attention with different backends
fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
dtype, config, "FusedAttention")
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, config, "UnfusedDotProductAttention")
param_names = []
param_names.append('hidden_states.grad')
params = []
params.append(hidden_states)
for name, param in mha.named_parameters():
if param.requires_grad:
param_names.append(name+'.grad')
params.append(param)
tols = dict(atol=2.5e-2, rtol=2.5e-2)
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
return out, param_names, tuple(x.grad for x in params)
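Both runners build input shapes by expanding a `qkv_format` string token-by-token through a `dim_to_num` table (with `'s'` rewritten to `'sq'` for the query side). A standalone sketch of that expansion, with hypothetical dimension values standing in for the test config:

```python
# Hypothetical dims chosen for illustration; mirrors the dim_to_num table above.
dim_to_num = {'b': 2, 'sq': 512, 'skv': 512, 'h': 16, 'hg': 16, 'd': 64,
              '3': 3, '2': 2, '1': 1}

def layout_to_shape(qkv_format):
    # 'bshd' -> 'b_s_h_d' -> 'b_sq_h_d' -> look up each token.
    layout = '_'.join(qkv_format)
    layout = layout.replace('s', 'sq')
    return [dim_to_num[tok] for tok in layout.split('_')]

print(layout_to_shape('bshd'))   # [2, 512, 16, 64]
print(layout_to_shape('sbhd'))   # [512, 2, 16, 64]
```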
def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
config = model_configs_fp8_vs_f16[model]
if config.num_heads != config.num_gqa_groups and '3' in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
if _NVTE_DEBUG:
print()
print("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout)
if _NVTE_DEBUG:
print("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout)
tols = dict(atol=5e-1, rtol=5e-2)
if _NVTE_DEBUG:
print('[test_dpa_fp8_vs_f16]: ', tols)
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
print('fused_attn_fwd RMSE: {:.6f}'.format(
_rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)))
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
for i,_ in enumerate(fused_attn_bwd_f16):
if _NVTE_DEBUG:
print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
print('fused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
print('fused_attn_bwd RMSE: {:.6f}'.format(
_rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])))
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_dpa,
)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
dpa = (
DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
)
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
layout = '_'.join(layout)
if i == 0:
layout = layout.replace('s', 'sq')
else:
layout = layout.replace('s', 'skv')
layout = layout.replace('h', 'hg')
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split('_')):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
else:
inp.append(tensors[j])
for i in range(3):
inp[i].requires_grad = True
qkv_format_kv = '_'.join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq')
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.1 * torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
out = dpa(inp[0], inp[1], inp[2],
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
)
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
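For packed layouts such as `sbh3d`, the runner above scans the layout tokens for a packing digit to decide how many tensors to split off and along which dimension; layouts without a digit hold a single q/k/v component. A minimal standalone version of that scan:

```python
def find_split(layout_tokens):
    # Locate the packing digit (e.g. '3' in ['sq','b','3','h','d']):
    # returns (tensor_count, split_dim) as in the loop above.
    for dim, tok in enumerate(layout_tokens):
        if tok.isdigit():
            return int(tok), dim
    return 1, 0  # no digit: the tensor is a single unpacked component

print(find_split(['sq', 'b', '3', 'h', 'd']))  # (3, 2)
print(find_split(['b', 'sq', 'h', 'd']))       # (1, 0)
```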
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1'))
models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6']
models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8']
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
"""Test FP8 dot product attention implementations based on cuDNN frontend
v0.9 and v1.0+. Each test compares results from a custom implementation of
an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
implementation, i.e. transformer_engine.pytorch.attention.MultiheadAttention.
Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""
config = model_configs_fp8[model]
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(
dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
unfused_attn_fwd_f16.min().item())
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
bwd_range = max(fused_attn_bwd_fp8.max().item(),
unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(),
unfused_attn_bwd_f16.min().item())
if _NVTE_DEBUG:
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()))
print('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format(
fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e:
print(e)
print()
print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()))
print('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()))
print('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format(
bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e:
print(e)
print()
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
assert(bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
def _run_custom_mha_fp8(dtype, config, backend):
"""Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -993,13 +1358,14 @@ def _run_dpa_fp8(dtype, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
inp = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
inp = 0.0001 * torch.randint(0, 100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
dtype=dtype, device="cuda", requires_grad=True)
seqlens = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
dtype=dtype, device="cuda")
@@ -1013,22 +1379,21 @@ def _run_dpa_fp8(dtype, config, backend):
amax_compute_algo="most_recent",
)
dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda")
mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = dpa(inp, cu_seqlens, config.max_seqlen_q)
out = mha(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad)
context = torch.load("ctx.pt")
out = torch.load("out.pt")
dqkv = torch.load('dqkv.pt')
return (context.view(config.batch_size, config.max_seqlen_q, -1).transpose(0,1),
return (out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())
config.num_heads, config.head_dim).contiguous())
def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
are converted from/to FP8"""
def _run_ref_mha_f16(dtype, config, backend):
"""Run reference F16 FusedAttention. Both input and output
are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -1043,7 +1408,7 @@ def _run_dpa_fp8_ref(dtype, config, backend):
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = torch.load('out_grad.pt').to(device="cuda").view(
config.batch_size, config.max_seqlen_q, -1).transpose(0,1)
config.batch_size, config.max_seqlen_q, -1)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1069,13 +1434,14 @@ def _run_dpa_fp8_ref(dtype, config, backend):
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self"
attention_type="self",
qkv_format="bshd",
).to(dtype=dtype, device="cuda")
)
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
q = inp[:,:,0,:,:]
k = inp[:,:,1,:,:]
v = inp[:,:,2,:,:]
out = block(q, k, v, attn_mask_type=config.attn_mask_type)
out.backward(out_grad)
@@ -1088,14 +1454,14 @@ _2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
META_S = tex.FP8FwdTensors.GEMM3_WEIGHT
META_DS = tex.FP8BwdTensors.GRAD_INPUT3
class _dpa_fp8(torch.autograd.Function):
class _custom_mha_fp8(torch.autograd.Function):
@staticmethod
def forward(
ctx,
@@ -1110,6 +1476,7 @@ class _dpa_fp8(torch.autograd.Function):
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
mask_type: str,
) -> torch.Tensor:
assert inp.dim() == 2
@@ -1117,14 +1484,10 @@
h = num_heads
d = in_features // h
b = cu_seqlens.numel() - 1
is_nl = False
if b < 4 and b > 1:
max_s = 512
is_nl = True
fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
inp_fp8, inp_t_fp8 = ext.fp8_cast_transpose_fused(
inp,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
@@ -1142,12 +1505,12 @@
ZInv = None
philox_unpacked = None
qkv_out, _ = ext.fp8_gemm(
qkv, _ = ext.fp8_gemm(
qkv_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat,
inp_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
@@ -1160,26 +1523,29 @@
use_split_accumulator=_2X_ACC_FPROP,
D_dtype=fp8_dtype_forward,
)
qkv_out = qkv_out.view(-1, 3, h, d)
qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
qkv = qkv.view(-1, 3, h, d)
qkv_fp16 = ext.cast_from_fp8(qkv, fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward,
tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
torch.save(qkv_out_fp16, 'qkv.pt')
tex.DType.kFloat16).view(b, max_s, 3, h, d).contiguous()
torch.save(qkv_fp16, 'qkv.pt')
if cudnn_frontend_version == 1:
qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
# FMHA
context_, aux_ctx_tensors, *rest = fused_attn_fwd(
out, aux_ctx_tensors, *rest = fused_attn_fwd(
is_training,
max_s,
max_s,
cu_seqlens,
cu_seqlens,
qkv_out[:,0,:,:],
qkv_out[:,1,:,:],
qkv_out[:,2,:,:],
qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
@@ -1187,20 +1553,17 @@
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="t3hd",
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type="padding",
attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
rng_gen=None,
)
M, ZInv, philox_unpacked = aux_ctx_tensors
context = context_.view(-1, in_features)
context_t = tex.fp8_transpose(context, fp8_dtype_forward)
M, ZInv, philox_unpacked = aux_ctx_tensors
ctx.save_for_backward(
inputmat_t, qkv_weight_t_fp8, workspace,
qkv_out,
context_, context_t,
inp_t_fp8, qkv_weight_t_fp8, workspace,
qkv, out,
fp8_meta["scaling_fwd"].scale,
fp8_meta["scaling_fwd"].scale_inv,
)
@@ -1210,14 +1573,16 @@
ctx.p_dropout = p_dropout
ctx.max_s = max_s
ctx.fast_zero_fill = fast_zero_fill
ctx.hidden_size = in_features
ctx.num_heads = num_heads
ctx.mask_type = mask_type
ctx.dtype = inp.dtype
out = out.view(-1, in_features) # (bs)(hd)
out_fp16 = ext.cast_from_fp8(out, fp8_meta["scaling_fwd"],
META_O, fp8_dtype_forward, tex.DType.kFloat16)
torch.save(out_fp16, 'out.pt') # (bs)(hd)
return out_fp16
@staticmethod
......@@ -1226,11 +1591,10 @@ class _dpa_fp8(torch.autograd.Function):
) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
(
inp_t_fp8,
qkv_weight_t_fp8,
workspace,
qkv, out,
fwd_scales,
fwd_scale_inverses,
) = ctx.saved_tensors
......@@ -1243,51 +1607,59 @@ class _dpa_fp8(torch.autograd.Function):
proj_dgrad = ext.cast_to_fp8(
grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
) # (bs)(hd)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
ctx.max_s,
ctx.cu_seqlens,
ctx.cu_seqlens,
qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
out,
proj_dgrad.view_as(out),
fp8_dtype_forward,
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
None,
attn_scale=None,
dropout=ctx.p_dropout,
fast_zero_fill=ctx.fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
)
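The `d_scale_*`/`q_scale_*`/`amax_*` arguments passed to `fused_attn_bwd` come from TE's delayed-scaling bookkeeping: each step records an amax, and the next quantization scale is derived from the rolling history. A sketch of that update (illustrative; the function name is ours, 448 is the E4M3 max):

```python
def update_fp8_scale(amax_history, fp8_max=448.0, margin=0):
    # Derive the next quantization scale from the rolling amax history;
    # scale_inv is what kernels consume as d_scale_* on the next step.
    amax = max(amax_history) or 1.0
    scale = (fp8_max / amax) / (2 ** margin)
    return scale, 1.0 / scale

scale, scale_inv = update_fp8_scale([0.0, 2.0, 1.5])
```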
dim = 2 if cudnn_frontend_version == 1 else 1
dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype)
dqkv_shape = list(dq.shape)
dqkv_shape.insert(dim, 3)
dqkv_stride = list(dq.stride())
dqkv_stride.insert(dim, int(dqkv_stride[-3]/3))
dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd
dqkv_c = dqkv.view(-1, 3*ctx.hidden_size)
dqkv_c_fp16 = ext.cast_from_fp8(dqkv_c,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, tex.DType.kFloat16)
torch.save(dqkv_c_fp16, 'dqkv.pt')
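The `Tensor.set_()` call here rebuilds one interleaved gradient tensor over the storage that dq/dk/dv are assumed to already share, by inserting a size-3 dimension whose stride is one third of the token stride, so no copy or `torch.cat` is needed. The same trick in numpy terms (illustrative toy sizes):

```python
import numpy as np

# Toy sizes: t flattened tokens, h heads, d head dim (hypothetical values).
t, h, d = 4, 2, 3
# One contiguous buffer holding dQ, dK, dV interleaved per token (t3hd-style);
# dq is a strided view into it, as fused_attn_bwd's outputs are assumed to be.
buf = np.arange(t * 3 * h * d, dtype=np.float32).reshape(t, 3, h, d)
dq = buf[:, 0]
# Insert a size-3 axis at dim 1 with stride strides[-3] // 3, mirroring set_().
shape = list(dq.shape)
shape.insert(1, 3)
strides = list(dq.strides)
strides.insert(1, strides[-3] // 3)
dqkv = np.lib.stride_tricks.as_strided(dq, shape=shape, strides=strides)
```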
qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused(
dqkv_c,
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
ctx.dtype,
)
# QKV DGRAD
......@@ -1296,25 +1668,25 @@ class _dpa_fp8(torch.autograd.Function):
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
dqkv_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
META_DQKV,
fp8_dtype_backward,
ctx.dtype,
workspace,
use_split_accumulator=_2X_ACC_DGRAD,
)
# QKV WGRAD
qkv_wgrad, _ = ext.fp8_gemm(
inp_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
dqkv_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
META_DQKV,
fp8_dtype_backward,
ctx.dtype,
workspace,
use_split_accumulator=_2X_ACC_WGRAD,
)
......@@ -1334,7 +1706,7 @@ class _dpa_fp8(torch.autograd.Function):
None)
class Custom_MHA_FP8(TransformerEngineBaseModule):
def __init__(
self,
config,
......@@ -1345,6 +1717,7 @@ class DPA_FP8(TransformerEngineBaseModule):
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.fast_zero_fill = True
self.mask_type = config.attn_mask_type
self.qkv_weight = torch.nn.Parameter(
torch.empty(
......@@ -1374,7 +1747,7 @@ class DPA_FP8(TransformerEngineBaseModule):
cu_seqlens, max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, None, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
inp,
self.qkv_weight,
self.qkv_bias,
......@@ -1385,7 +1758,8 @@ class DPA_FP8(TransformerEngineBaseModule):
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training,
self.mask_type)
return out
def get_fp8_weights_scratchpad(
......
......@@ -1091,7 +1091,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output.
atol = {torch.float32 : 2.5e-4,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
......
......@@ -85,15 +85,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3)
|| (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (
((cudnn_runtime_version >= 8900)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)
&& (max_seqlen_q == max_seqlen_kv)
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))
|| ((cudnn_runtime_version >= 90100)
&& (max_seqlen_q % 128 == 0)
&& (max_seqlen_kv % 128 == 0)
&& (head_dim == 128)
&& ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
......@@ -269,7 +279,7 @@ void nvte_fused_attn_fwd_qkvpacked(
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(
b, h, max_seqlen, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
......@@ -379,7 +389,7 @@ void nvte_fused_attn_bwd_qkvpacked(
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked(
b, h, max_seqlen, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
......@@ -476,7 +486,18 @@ void nvte_fused_attn_fwd_kvpacked(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
......@@ -580,7 +601,23 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
output_dQ, output_dKV,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
......@@ -662,8 +699,8 @@ void nvte_fused_attn_fwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
......@@ -775,8 +812,8 @@ void nvte_fused_attn_bwd(
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
......
......@@ -76,7 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
tensorType, tensorType};
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
......@@ -147,7 +147,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
fe::graph::SDPA_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_is_inference(false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
......@@ -199,11 +199,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1});
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
......@@ -211,7 +209,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes> > // O
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
......@@ -258,11 +256,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
{K, devPtrK},
{V, devPtrV},
{attn_scale, &scaling_factor},
{O, devPtrO},
{Stats, devPtrSoftmaxStats}};
if (is_bias) {
variant_pack[bias] = devPtrBias;
......@@ -321,7 +316,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
tensorType, tensorType};
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
......
......@@ -19,7 +19,7 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
......
......@@ -8,6 +8,7 @@
#include "../common.h"
#include "utils.h"
#include "../util/system.h"
#include "fused_attn_fp8.h"
namespace transformer_engine {
......@@ -984,7 +985,7 @@ static cudnn_frontend::Tensor createdSQBMM(
return After_dSTranspose_Q;
}
// fused attention FWD FP8
// fused attention FWD FP8 with FE 0.9
void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
......@@ -1295,7 +1296,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
}
}
// fused attention BWD FP8
// fused attention BWD FP8 with FE 0.9
void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
......@@ -1846,6 +1847,707 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
}
}
// fused attention FWD FP8 with FE 1.0+
void fused_attn_fp8_fwd_impl_v1(int64_t b, int64_t h, int64_t hg,
int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor,
float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
void* devPtrO,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
void* devPtrAmaxO, void* devPtrAmaxS,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnn_frontend::DataType_t fwd_tensor_type,
void* workspace,
size_t* workspace_size,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
auto bias_b = b;
auto bias_h = h;
NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
NVTE_CHECK(!is_padding,
"FP8 fused attention does not support padding/padding_causal mask yet!");
NVTE_CHECK(!is_dropout, "FP8 fused attention does not support dropout yet!");
try {
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
fwd_tensor_type, fwd_tensor_type};
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_s
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_o
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_fp8_fprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto graph = it->second;
return graph;
}
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(fwd_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> descale_q, descale_k, descale_v;
std::shared_ptr<fe::graph::Tensor_attributes> descale_s, scale_s, scale_o;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
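Note that K and V are declared with hg (num_gqa_groups) heads while Q uses h, i.e. grouped-query attention: each contiguous group of h/hg query heads attends against the same KV head. The mapping, sketched (illustrative helper, not a TE function):

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    # Each contiguous group of num_q_heads // num_kv_heads query heads
    # shares one KV head.
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# 8 query heads over 2 KV heads: heads 0-3 -> KV head 0, heads 4-7 -> KV head 1.
mapping = [kv_head_for_query_head(i, 8, 2) for i in range(8)]
```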
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
descale_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Descale_q")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
descale_k = mha_graph->tensor_like(descale_q, "Descale_K");
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
fe::graph::SDPA_fp8_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_fp8_attributes()
.set_name("sdpa_fp8")
.set_is_inference(false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
// sdpa_options.set_alibi_mask(is_alibi);
// if (is_bias) {
// bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({bias_b, bias_h, s_q, s_kv})
// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
// sdpa_options.set_bias(bias);
// }
// if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// sdpa_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
// if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// sdpa_options.set_dropout(
// dropout_probability, dropout_seed, dropout_offset);
// }
auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8(
Q, K, V, descale_q, descale_k, descale_v, descale_s,
scale_s, scale_o, sdpa_options);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1});
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_s
std::shared_ptr<fe::graph::Tensor_attributes> > // amax_o
key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v,
descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o);
auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s,
scale_s, scale_o, attn_scale, O, amax_s, amax_o, Stats,
bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_fp8_fprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
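The early return above follows the cuDNN-style two-call workspace protocol: the caller first invokes the function with workspace == nullptr to learn how many bytes to allocate, then calls again with a real buffer. In outline (illustrative Python with made-up sizes):

```python
def fused_attn_call(b, workspace=None):
    plan_workspace_size = 1024          # hypothetical plan requirement
    seqlen_workspace_size = 2 * b * 4   # 2 * b * sizeof(int32_t)
    total = plan_workspace_size + seqlen_workspace_size
    if workspace is None:
        # Size query: report the requirement and do no work.
        return total
    assert len(workspace) >= total
    return "executed"

size = fused_attn_call(b=8)
buf = bytearray(size)
result = fused_attn_call(b=8, workspace=buf)
```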
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ},
{K, devPtrK},
{V, devPtrV},
{descale_q, devPtrDescaleQ},
{descale_k, devPtrDescaleK},
{descale_v, devPtrDescaleV},
{descale_s, devPtrDescaleS},
{scale_s, devPtrScaleS},
{scale_o, devPtrScaleO},
{attn_scale, &scaling_factor},
{O, devPtrO},
{amax_s, devPtrAmaxS},
{amax_o, devPtrAmaxO},
{Stats, devPtrM}};
// if (is_bias) {
// variant_pack[bias] = devPtrBias;
// }
// if (is_padding) {
// constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
// + b * sizeof(int32_t);
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV),
// static_cast<int32_t *>(devActualSeqlenQ),
// static_cast<int32_t *>(devActualSeqlenKV));
// variant_pack[seq_q] = devActualSeqlenQ;
// variant_pack[seq_kv] = devActualSeqlenKV;
// }
// if (is_dropout) {
// variant_pack[dropout_seed] = devPtrDropoutSeed;
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
// fused attention BWD FP8 with FE 1.0+
void fused_attn_fp8_bwd_impl_v1(int64_t b, int64_t h, int64_t hg,
int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
void* devPtrO, void* devPtrdO,
void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleO, void* devPtrDescaledO,
void* devPtrDescaleS, void* devPtrDescaledP,
void* devPtrScaleS, void* devPtrScaledP,
void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdP,
void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnn_frontend::DataType_t fwd_tensor_type,
cudnn_frontend::DataType_t bwd_tensor_type,
void* workspace,
size_t* workspace_size,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
auto bias_b = b;
auto bias_h = h;
NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
NVTE_CHECK(!is_padding,
"FP8 fused attention does not support padding/padding_causal mask yet!");
NVTE_CHECK(!is_dropout, "FP8 fused attention does not support dropout yet!");
try {
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
fwd_tensor_type, bwd_tensor_type};
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dO
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dV
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dV
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_fp8_bprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto graph = it->second;
return graph;
}
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(fwd_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> descale_q, descale_k, descale_v;
std::shared_ptr<fe::graph::Tensor_attributes> descale_s, descale_o;
std::shared_ptr<fe::graph::Tensor_attributes> descale_dP, descale_dO;
std::shared_ptr<fe::graph::Tensor_attributes> scale_s, scale_dP;
std::shared_ptr<fe::graph::Tensor_attributes> scale_dQ, scale_dK, scale_dV;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
descale_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Descale_Q")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
descale_k = mha_graph->tensor_like(descale_q, "Descale_K");
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP");
descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP");
scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options;
sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes()
.set_name("sdpa_fp8_backward")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
// sdpa_backward_options.set_alibi_mask(is_alibi);
// if (is_bias) {
// bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({bias_b, bias_h, s_q, s_kv})
// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
// dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("dBias")
// .set_dim({bias_b, bias_h, s_q, s_kv})
// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
// sdpa_backward_options.set_bias(bias);
// // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// // are not supported for dbias calculation but they are
// // supported for forward bias calculation
// if ((bias_b == 1) && (bias_h == h)) {
// sdpa_backward_options.set_dbias(dBias);
// }
// }
// if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// sdpa_backward_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
// if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// sdpa_backward_options.set_dropout(
// dropout_probability, dropout_seed, dropout_offset);
// }
auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward(
q, k, v, o, dO, stats,
descale_q, descale_k, descale_v,
descale_o, descale_dO, descale_s, descale_dP,
scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
sdpa_backward_options);
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride);
amax_dQ->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dK->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dV->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dP->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
dO->set_data_type(bwd_tensor_type);
dQ->set_data_type(bwd_tensor_type);
dK->set_data_type(bwd_tensor_type);
dV->set_data_type(bwd_tensor_type);
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dO
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dV
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dV
std::shared_ptr<fe::graph::Tensor_attributes> > // amax_dP
key_tensors_tuple = std::make_tuple(
q, k, v, o, stats, dO, attn_scale,
descale_q, descale_k, descale_v,
descale_o, descale_dO, descale_s, descale_dP,
scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
dQ, dK, dV,
amax_dQ, amax_dK, amax_dV, amax_dP);
auto bias_tuple = is_bias ?
std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, stats, dO, attn_scale,
descale_q, descale_k, descale_v,
descale_o, descale_dO, descale_s, descale_dP,
scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP,
bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_fp8_bprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{q, devPtrQ},
{k, devPtrK},
{v, devPtrV},
{o, devPtrO},
{stats, devPtrM},
{dO, devPtrdO},
{attn_scale, &scaling_factor},
{descale_q, devPtrDescaleQ},
{descale_k, devPtrDescaleK},
{descale_v, devPtrDescaleV},
{descale_o, devPtrDescaleO},
{descale_dO, devPtrDescaledO},
{descale_s, devPtrDescaleS},
{descale_dP, devPtrDescaledP},
{scale_s, devPtrScaleS},
{scale_dQ, devPtrScaledQ},
{scale_dK, devPtrScaledK},
{scale_dV, devPtrScaledV},
{scale_dP, devPtrScaledP},
{dQ, devPtrdQ},
{dK, devPtrdK},
{dV, devPtrdV},
{amax_dQ, devPtrAmaxdQ},
{amax_dK, devPtrAmaxdK},
{amax_dV, devPtrAmaxdV},
{amax_dP, devPtrAmaxdP},
};
// if (is_bias) {
// variant_pack[bias] = devPtrBias;
// if ((bias_b == 1) && (bias_h == h)) {
// variant_pack[dBias] = devPtrdBias;
// } else {
// variant_pack[dBias] = nullptr;
// }
// }
// if (is_padding) {
// constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
// + b * sizeof(int32_t);
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV),
// static_cast<int32_t *>(devActualSeqlenQ),
// static_cast<int32_t *>(devActualSeqlenKV));
// variant_pack[seq_q] = devActualSeqlenQ;
// variant_pack[seq_kv] = devActualSeqlenKV;
// }
// if (is_dropout) {
// variant_pack[dropout_seed] = devPtrDropoutSeed;
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
#endif
} // namespace fused_attn
@@ -1853,9 +2555,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
@@ -1866,11 +2569,18 @@ void fused_attn_fp8_fwd_qkvpacked(
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
// Byte offset between the Q, K and V matrices inside the packed QKV tensor:
// 3HD packs [..., 3, h, d] (offset of h * d elements); H3D packs [..., h, 3, d]
// (offset of d elements).
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
void* devPtrDescaleK = input_QKV->scale_inv.dptr;
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
@@ -1882,21 +2592,19 @@ void fused_attn_fp8_fwd_qkvpacked(
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
@@ -1919,11 +2627,27 @@ void fused_attn_fp8_fwd_qkvpacked(
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
@@ -1935,6 +2659,9 @@ void fused_attn_fp8_fwd_qkvpacked(
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd.\n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1950,8 +2677,9 @@ void fused_attn_fp8_fwd_qkvpacked(
}
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO,
@@ -1966,11 +2694,19 @@ void fused_attn_fp8_bwd_qkvpacked(
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const DType dQKV_type = output_dQKV->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
void* devPtrDescaleK = input_QKV->scale_inv.dptr;
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
@@ -1985,15 +2721,14 @@ void fused_attn_fp8_bwd_qkvpacked(
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdP = input_output_dP->amax.dptr;
void* devPtrScaledP = input_output_dP->scale.dptr;
void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void* devPtrAmaxdQ = output_dQKV->amax.dptr;
void* devPtrAmaxdK = output_dQKV->amax.dptr;
void* devPtrAmaxdV = output_dQKV->amax.dptr;
@@ -2008,11 +2743,33 @@ void fused_attn_fp8_bwd_qkvpacked(
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
@@ -2020,15 +2777,278 @@ void fused_attn_fp8_bwd_qkvpacked(
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd.\n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
// Byte offset between the K and V matrices inside the packed KV tensor:
// HD_2HD packs [..., 2, hg, d] (offset of hg * d elements); HD_H2D packs
// [..., hg, 2, d] (offset of d elements).
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_KV->scale_inv.dptr;
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd.\n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
Tensor *output_dQ,
Tensor *output_dKV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_KV->scale_inv.dptr;
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdP = input_output_dP->amax.dptr;
void* devPtrScaledP = input_output_dP->scale.dptr;
void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdKV = output_dKV->data.dptr;
void *devPtrdK = devPtrdKV;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdKV) + stride);
void* devPtrAmaxdQ = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dKV->amax.dptr;
void* devPtrAmaxdV = output_dKV->amax.dptr;
void* devPtrScaledQ = output_dQ->scale.dptr;
void* devPtrScaledK = output_dKV->scale.dptr;
void* devPtrScaledV = output_dKV->scale.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd.\n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -2044,9 +3064,11 @@ void fused_attn_fp8_bwd_qkvpacked(
}
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_K,
const Tensor *input_V,
@@ -2074,21 +3096,19 @@ void fused_attn_fp8_fwd(
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
@@ -2116,8 +3136,25 @@ void fused_attn_fp8_fwd(
const DType QKV_type = input_Q->data.dtype;
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
......@@ -2129,6 +3166,9 @@ void fused_attn_fp8_fwd(
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd.\n");
}
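The dispatch above keys on the QKV format that `nvte_get_qkv_format` derives from the layout enum: `bshd`/`sbhd` formats route to the new `fused_attn_fp8_fwd_impl_v1` path, while only `t3hd` routes to the legacy implementation. A minimal Python sketch of that derivation (a hypothetical string-based mirror of the real C enum logic):

```python
def get_qkv_format(qkv_layout: str) -> str:
    """Map a layout name (e.g. "t3hd", "sb3hd", "bshd_bshd_bshd") to its QKV format."""
    head = qkv_layout.split('_')[0]
    if head.startswith('t'):
        return 'thd'    # ragged/packed sequences
    if head.startswith('sb'):
        return 'sbhd'   # sequence-first
    if head.startswith('bs'):
        return 'bshd'   # batch-first
    raise ValueError(f"unrecognized layout: {qkv_layout}")
```

Under this sketch, only a `thd`-format layout (specifically `t3hd`) falls through to the older FP8 kernel; everything else must be `bshd`/`sbhd` or the error branch fires.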
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -2144,8 +3184,10 @@ void fused_attn_fp8_fwd(
}
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_K,
const Tensor *input_V,
......@@ -2182,9 +3224,9 @@ void fused_attn_fp8_bwd(
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdS = input_output_dP->amax.dptr;
void* devPtrScaledS = input_output_dP->scale.dptr;
void* devPtrDescaledS = input_output_dP->scale_inv.dptr;
void* devPtrAmaxdP = input_output_dP->amax.dptr;
void* devPtrScaledP = input_output_dP->scale.dptr;
void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
void* devPtrdQ = output_dQ->data.dptr;
void* devPtrdK = output_dK->data.dptr;
......@@ -2206,10 +3248,34 @@ void fused_attn_fp8_bwd(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
b, h, max_seqlen_q, max_seqlen_kv, d,
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
......@@ -2217,15 +3283,18 @@ void fused_attn_fp8_bwd(
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledS,
devPtrScaleS, devPtrScaledS,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdS,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd.\n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -14,9 +14,10 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
......@@ -29,8 +30,9 @@ void fused_attn_fp8_fwd_qkvpacked(
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO,
......@@ -45,11 +47,55 @@ void fused_attn_fp8_bwd_qkvpacked(
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dKV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
Tensor *input_output_S,
Tensor *output_O,
......@@ -63,8 +109,10 @@ void fused_attn_fp8_fwd(
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O,
const Tensor *input_dO,
......@@ -111,19 +111,20 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
cudnn_frontend::DataType_t tensor_type;
cudnn_frontend::DataType_t fwd_tensor_type;
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type)
layout, mask_type, bias_type, fwd_tensor_type, bwd_tensor_type)
< std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type,
rhs.tensor_type);
rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
};
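Splitting `tensor_type` into `fwd_tensor_type` and `bwd_tensor_type` matters because cached cuDNN graphs are looked up by this descriptor: two runs that differ only in the backward dtype (e.g. FP8 E4M3 forward with E5M2 gradients vs. BF16 gradients) must not collide on the same cached plan. A toy Python sketch of such a descriptor-keyed cache (hypothetical field subset, stand-in graph objects):

```python
from typing import NamedTuple

class FADescriptor(NamedTuple):
    b: int
    h: int
    s_q: int
    s_kv: int
    d: int
    fwd_tensor_type: str
    bwd_tensor_type: str  # newly part of the cache key

graph_cache = {}

def get_graph(desc):
    # build once per distinct descriptor, reuse thereafter;
    # tuple equality/ordering plays the role of operator< in the C++ map
    if desc not in graph_cache:
        graph_cache[desc] = object()  # stand-in for a built cuDNN graph
    return graph_cache[desc]

d1 = FADescriptor(2, 16, 512, 512, 64, "FP8_E4M3", "FP8_E5M2")
d2 = d1._replace(bwd_tensor_type="BFLOAT16")
```

With only a single `tensor_type` field, `d1` and `d2` would hash to the same entry and the second configuration would silently reuse the wrong graph.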
......@@ -96,7 +96,7 @@ class DelayedScaling:
where `Tensor` is a framework tensor type.
override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
Whether or not the execute the `fprop`, `dgrad`, and `wgrad`
Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
GEMMs (respectively) in higher precision when using FP8.
reduce_amax: bool, default = `True`
By default, if `torch.distributed` is initialized, the `amax` value for FP8
......@@ -106,6 +106,20 @@ class DelayedScaling:
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts its
inputs from higher precision to FP8, performs attention in FP8, and casts its
outputs back to higher precision. FP8 DPA is currently only supported by the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
Whether to enable FP8 multi-head attention (MHA). When `True`, the casting
operations at the DPA boundaries described above are removed. Currently only standard
MHA modules, i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this
feature. With `fp8_mha = False, fp8_dpa = True`, a typical MHA module runs as
`LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear`.
With `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
Notes
-----
......@@ -116,6 +130,9 @@ class DelayedScaling:
FP8_MAX = maximum_representable_value(fp8_format)
new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)
* `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
subject to change in future Transformer Engine releases.
"""
margin: int = 0
......@@ -126,6 +143,8 @@ class DelayedScaling:
override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
scaling_factor_compute_algo: Optional[Callable] = None
reduce_amax: bool = True
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
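The amax-to-scale update quoted in the `DelayedScaling` docstring above can be checked numerically. A small sketch, using the E4M3 maximum magnitude (448) as `FP8_MAX`:

```python
FP8_MAX_E4M3 = 448.0  # maximum representable magnitude in FP8 E4M3

def new_scaling_factor(amax: float, margin: int = 0) -> float:
    # new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)
    return (FP8_MAX_E4M3 / amax) / (2 ** margin)
```

So a tensor whose running amax exactly fills the E4M3 range gets scale 1.0, and each unit of `margin` halves the scale to leave headroom.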
......@@ -19,6 +19,10 @@ import torch
import torch.nn.functional as F
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
......@@ -31,7 +35,10 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
divide,
attention_mask_func,
......@@ -74,6 +81,12 @@ if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_alibi_cache = {
......@@ -810,7 +823,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
......@@ -850,7 +863,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
......@@ -890,7 +903,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse_, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
......@@ -923,7 +936,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q, kv[0], kv[1], out, dout, TE_DType[q.dtype],
q, kv[0], kv[1], out, dout, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
......@@ -1246,6 +1259,14 @@ class _SplitAlongDim(torch.autograd.Function):
) -> Tuple[torch.Tensor, ...]:
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
if isinstance(mixed_x_layer, Float8Tensor):
return tuple(Float8Tensor.make_like(
mixed_x_layer,
data=x,
) for x in torch.split(
mixed_x_layer._data,
split_size_or_sections=split_size_or_sections,
dim=split_dim))
return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)
@staticmethod
......@@ -1262,6 +1283,37 @@ class _SplitAlongDim(torch.autograd.Function):
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
if (tensor.stride() != strides or
list(tensor.shape) != shape_i or
tensor._data.untyped_storage().data_ptr() != data_ptr or
tensor.storage_offset() != offset_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0]._data.dtype)
new_shape = list(shape)
new_shape[split_dim] = sum(split_sizes)
ret.set_(grad_outputs[0]._data.untyped_storage(),
grad_outputs[0]._data.storage_offset(),
new_shape,
strides
)
return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None
grad_outputs_data = [x._data for x in grad_outputs]
return Float8Tensor.make_like(
grad_outputs[0],
data=torch.cat(grad_outputs_data, dim = split_dim)), None, None
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].untyped_storage().data_ptr()
......@@ -1276,7 +1328,6 @@ class _SplitAlongDim(torch.autograd.Function):
tensor.storage_offset() != offset_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype)
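The `noop_ok` logic above checks that the incoming gradients are still contiguous, in-order slices of one allocation; if so, "concatenating" them is just re-viewing that buffer rather than allocating and copying. A numpy sketch of the same idea, using `np.shares_memory` and data-pointer offsets in place of the raw stride/storage-offset checks (illustrative only, not the Float8Tensor code path):

```python
import numpy as np

buf = np.arange(24, dtype=np.float32).reshape(6, 4)
parts = np.split(buf, [2, 4], axis=0)      # three views, no copies

# every part aliases the original allocation, in order and contiguously
noop_ok = all(np.shares_memory(p, buf) for p in parts)
offsets = [p.__array_interface__['data'][0] - buf.__array_interface__['data'][0]
           for p in parts]
noop_ok &= offsets == sorted(offsets)

# so the "concatenation" can be a zero-copy re-view of the buffer
cat = buf if noop_ok else np.concatenate(parts, axis=0)
```

When any gradient was reallocated (different storage, stride, or offset), the check fails and the code falls back to a real concatenation, exactly as the `torch.cat` branch does above.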
......@@ -1848,6 +1899,35 @@ class FlashAttention(torch.nn.Module):
return output
def _combine_tensors(
tensors: List[torch.Tensor],
dim: int,
) -> torch.Tensor:
"""Combine tensors along a particular dimension"""
num_tensors = len(tensors)
new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors)
new_stride = list(tensors[0].stride())
new_stride.insert(dim, int(new_stride[dim-1]/num_tensors))
if isinstance(tensors[0], Float8Tensor):
combined_tensor = torch.Tensor().to(
device=tensors[0].device, dtype=tensors[0]._data.dtype)
combined_tensor.set_(
tensors[0]._data.untyped_storage(),
tensors[0]._data.storage_offset(),
new_shape, new_stride)
combined_tensor = Float8Tensor.make_like(
tensors[0], data=combined_tensor)
else:
combined_tensor = torch.Tensor().to(
device=tensors[0].device, dtype=tensors[0].dtype)
combined_tensor.set_(
tensors[0].untyped_storage(),
tensors[0].storage_offset(),
new_shape, new_stride)
return combined_tensor
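`_combine_tensors` rebuilds a packed tensor zero-copy: it inserts one dimension into the shape, inserts one stride equal to `stride[dim-1] / num_tensors`, and re-views the first tensor's storage. A numpy sketch of the same stride arithmetic (numpy strides are in bytes, so they already include `itemsize`; assumes the inputs really interleave in one buffer):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

s, three, h, d = 2, 3, 2, 4
base = np.arange(s * three * h * d, dtype=np.float32).reshape(s, three, h, d)
q, k, v = base[:, 0], base[:, 1], base[:, 2]   # interleaved views of one buffer

dim = 1                                        # where the packing dim goes
new_shape = list(q.shape)
new_shape.insert(dim, 3)
new_stride = list(q.strides)                   # byte strides
new_stride.insert(dim, new_stride[dim - 1] // 3)

combined = as_strided(q, shape=new_shape, strides=new_stride)
```

This mirrors `new_stride.insert(dim, int(new_stride[dim-1]/num_tensors))` in the code above; if q/k/v came from separate allocations, the re-view would read unrelated memory, which is why the function only makes sense for tensors split out of one packed buffer.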
class FusedAttnFunc_qkvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed QKV input"""
......@@ -1855,15 +1935,83 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend, use_FAv2_bwd):
out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(qkv, out, cu_seqlens)
rng_gen, fused_attention_backend, use_FAv2_bwd,
fp8, fp8_meta, tp_size, tp_group):
if fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 forward')
if fp8_meta["recipe"].fp8_mha:
assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
assert (qkv_group == 1
), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \
but found {qkv_layout}."
if fp8_meta["recipe"].fp8_mha:
qkv_fp8 = qkv._data
else:
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_fp8 = cast_to_fp8(qkv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(qkv.shape)
out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens,
qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale, dropout_p, fast_zero_fill, qkv_layout,
attn_bias_type, attn_mask_type, rng_gen)
if fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=qkv.dtype,
)
else:
out_ret = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv = cast_from_fp8(qkv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
fp8_tensors = (qkv_fp8, out_fp8,
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone())
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 forward')
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias,
None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
fp8_tensors = (None, None, None, None)
out_save = out_ret
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
ctx.fp8_meta = fp8_meta
ctx.tp_size = tp_size
ctx.tp_group = tp_group
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
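The forward paths above distinguish packed layouts by counting underscore-separated groups in the layout string, `len(qkv_layout.split('_'))`. A minimal sketch of that check:

```python
def qkv_group(qkv_layout: str) -> int:
    """1: qkv packed, 2: kv packed, 3: q, k, v separate."""
    return len(qkv_layout.split('_'))
```

The qkvpacked autograd function asserts `qkv_group == 1` (layouts like `sb3hd`), the kvpacked one asserts `qkv_group == 2` (e.g. `sbhd_sb2hd`), and the separate-tensor path handles all three group counts.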
......@@ -1873,15 +2021,23 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
ctx.fused_attention_backend = \
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
ctx.use_FAv2_bwd = use_FAv2_bwd
return out
return out_ret
@staticmethod
def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha:
assert (isinstance(d_out, Float8Tensor)
), "Gradient of the DPA output must be a Float8Tensor for FP8 MHA."
d_out_f8tensor = d_out
d_out = d_out._data
d_out = d_out.contiguous()
qkv, out, cu_seqlens = ctx.saved_tensors
(qkv, out, cu_seqlens,
qkv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
......@@ -1898,13 +2054,65 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
)
dqkv = dqkv[..., :d_out.shape[-1]]
else:
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
if ctx.fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False)
if ctx.fp8_meta["recipe"].fp8_mha:
d_out_fp8 = d_out
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else:
d_out_fp8 = cast_to_fp8(
d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
).view(d_out.shape)
dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens,
qkv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
if ctx.fp8_meta["recipe"].fp8_mha:
dqkv = Float8Tensor(data=dqkv_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
else:
dqkv_c_fp8 = dqkv_fp8.view(-1,
dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
dqkv = cast_from_fp8(dqkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 backward')
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(qkv.dtype)
dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
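When `fp8_mha=False`, the boundaries around FP8 DPA are explicit `cast_to_fp8`/`cast_from_fp8` pairs driven by the meta's `scale` and `scale_inv`. A toy numpy sketch of that scale/saturate/descale roundtrip (uniform saturation only; it does not model real E4M3 rounding):

```python
import numpy as np

FP8_MAX = 448.0  # E4M3 max magnitude

def cast_roundtrip(x, amax):
    scale = FP8_MAX / amax                         # as in scaling_fwd.scale
    x_fp8 = np.clip(x * scale, -FP8_MAX, FP8_MAX)  # "cast_to_fp8" (saturation only)
    return x_fp8 / scale                           # "cast_from_fp8" via scale_inv

x = np.linspace(-2.0, 2.0, 9, dtype=np.float32)
y = cast_roundtrip(x, amax=2.0)                            # amax covers the range
z = cast_roundtrip(np.array([4.0], dtype=np.float32),
                   amax=2.0)                               # out-of-range saturates
```

Values inside the amax-calibrated range survive the roundtrip (up to quantization error, which this sketch omits), while anything beyond it saturates to the amax boundary; this is why the delayed-scaling amax history feeds the scales used at these cast sites.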
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
......@@ -1923,16 +2131,90 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend, use_FAv2_bwd):
out, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv)
qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
if fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 forward')
if fp8_meta["recipe"].fp8_mha:
assert (isinstance(q, Float8Tensor)
and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
q_fp8, kv_fp8 = q._data, kv._data
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
assert (qkv_group == 2
), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \
but found {qkv_layout}."
q_fp8 = cast_to_fp8(q,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(q.shape)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = cast_to_fp8(kv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(kv.shape)
out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale, dropout_p, fast_zero_fill, qkv_layout,
attn_bias_type, attn_mask_type, rng_gen)
if fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q.dtype,
)
else:
out_ret = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q = cast_from_fp8(q._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv = cast_from_fp8(kv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
fp8_tensors = (q_fp8, kv_fp8, out_fp8,
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone())
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 forward')
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
out_save = out_ret
fp8_tensors = (None, None, None, None, None)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
ctx.fp8_meta = fp8_meta
ctx.tp_size = tp_size
ctx.tp_group = tp_group
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
......@@ -1943,15 +2225,23 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
ctx.fused_attention_backend = \
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
ctx.use_FAv2_bwd = use_FAv2_bwd
return out
return out_ret
@staticmethod
def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha:
assert (isinstance(d_out, Float8Tensor)
), "Gradient of the DPA output must be a Float8Tensor for FP8 MHA."
d_out_f8tensor = d_out
d_out = d_out._data
d_out = d_out.contiguous()
q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
(q, kv, out, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
......@@ -1970,14 +2260,77 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
dq = dq[..., :d_out.shape[-1]]
dkv = dkv[..., :d_out.shape[-1]]
else:
dq, dkv, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
if ctx.fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False)
if ctx.fp8_meta["recipe"].fp8_mha:
d_out_fp8 = d_out
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else:
d_out_fp8 = cast_to_fp8(
d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
).view(d_out.shape)
dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor(data=dq_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
dkv = Float8Tensor(data=dkv_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
else:
dq = cast_from_fp8(
dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
dkv_c_fp8 = dkv_fp8.view(-1,
dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
dkv = cast_from_fp8(dkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 backward')
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dkv, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
......@@ -1989,32 +2342,153 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None, None, None, None, None, None,
None, None, None, None, None, None)
class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors"""
@staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
if fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 forward')
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
assert (isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
if qkv_group == 1:
dim = qkv_layout.find('3')
qkv = _combine_tensors([q,k,v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_fp8 = cast_to_fp8(qkv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(qkv.shape)
q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1])
q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
if qkv_group == 2:
q_fp8 = cast_to_fp8(q,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(q.shape)
dim = qkv_layout.split('_')[1].find('2')
kv = _combine_tensors([k,v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = cast_to_fp8(kv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(kv.shape)
k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1])
k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
if qkv_group == 3:
q_fp8 = cast_to_fp8(q,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(q.shape)
k_fp8 = cast_to_fp8(k,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(k.shape)
v_fp8 = cast_to_fp8(v,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(v.shape)
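The layout-string parsing above (`len(qkv_layout.split('_'))` plus `find('3')`/`find('2')`) can be sketched standalone. `qkv_group_info` below is a hypothetical helper written only to illustrate that logic, under the assumption that layouts follow the TE naming convention (e.g. "sbh3d", "sbhd_sbh2d", "sbhd_sbhd_sbhd"):

```python
# Hypothetical helper illustrating how the code above classifies a
# qkv_layout string: the number of '_'-separated parts gives the packing
# group, and the position of '3'/'2' gives the packing axis.
def qkv_group_info(qkv_layout: str):
    parts = qkv_layout.split('_')
    group = len(parts)              # 1: qkv packed, 2: kv packed, 3: separate
    if group == 1:
        dim = qkv_layout.find('3')  # axis holding the packed q/k/v triple
    elif group == 2:
        dim = parts[1].find('2')    # axis holding the packed k/v pair
    else:
        dim = -1                    # nothing packed
    return group, dim

assert qkv_group_info("sbh3d") == (1, 3)
assert qkv_group_info("sbhd_sbh2d") == (2, 3)
assert qkv_group_info("sbhd_sbhd_sbhd") == (3, -1)
```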
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale, dropout_p, fast_zero_fill, qkv_layout,
attn_bias_type, attn_mask_type, rng_gen)
if fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q.dtype,
)
else:
out_ret = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
if qkv_group == 1:
dim = qkv_layout.find('3')
qkv = _combine_tensors([q,k,v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_no_fp8 = cast_from_fp8(qkv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1])
q, k, v = [x.squeeze(dim) for x in [q, k, v]]
if qkv_group == 2:
q = cast_from_fp8(q._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
dim = qkv_layout.split('_')[1].find('2')
kv = _combine_tensors([k,v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = cast_from_fp8(kv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1])
k, v = [x.squeeze(dim) for x in [k, v]]
if qkv_group == 3:
q = cast_from_fp8(q._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
k = cast_from_fp8(k._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape)
v = cast_from_fp8(v._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape)
out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8,
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone())
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 forward')
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen)
out_save = out_ret
fp8_tensors = (None, None, None, None, None, None)
from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
qkv_layout = 'sbhd_sbhd_sbhd'
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
ctx.fp8_meta = fp8_meta
ctx.tp_size = tp_size
ctx.tp_group = tp_group
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
......@@ -2025,15 +2499,23 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = \
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
ctx.use_FAv2_bwd = use_FAv2_bwd
return out_ret
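The `ctx.fp8` gate set in forward() above combines the FP8 flag with the `NVTE_FP8_DPA_BWD` environment variable, so an FP8 forward can still fall back to an F16 backward. A minimal sketch of that gate (the helper name `use_fp8_bwd` is illustrative only):

```python
# Sketch of the backward-path gate used above: FP8 backward runs only when
# FP8 is active AND NVTE_FP8_DPA_BWD is unset or not "0".
import os

def use_fp8_bwd(fp8_active: bool) -> bool:
    return fp8_active and bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))

os.environ["NVTE_FP8_DPA_BWD"] = "0"
assert use_fp8_bwd(True) is False       # FP8 forward, F16 backward
os.environ["NVTE_FP8_DPA_BWD"] = "1"
assert use_fp8_bwd(True) is True        # FP8 forward and backward
```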
@staticmethod
def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha:
assert (isinstance(d_out, Float8Tensor)
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
d_out_f8tensor = d_out
d_out = d_out._data
d_out = d_out.contiguous()
(q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd:
......@@ -2054,14 +2536,112 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., :d_out.shape[-1]]
dv = dv[..., :d_out.shape[-1]]
else:
with torch.cuda.nvtx.range("_FusedAttn"):
if ctx.fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False)
if ctx.fp8_meta["recipe"].fp8_mha:
d_out_fp8 = d_out
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else:
d_out_fp8 = cast_to_fp8(
d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
).view(d_out.shape)
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor(data=dq_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
dk = Float8Tensor(data=dk_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
dv = Float8Tensor(data=dv_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
else:
qkv_group = len(ctx.qkv_layout.split('_'))
if qkv_group == 1:
dim = ctx.qkv_layout.find('3')
dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim)
dqkv_c_fp8 = dqkv_fp8.view(-1,
dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
dqkv = cast_from_fp8(dqkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1])
dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
if qkv_group == 2:
dq = cast_from_fp8(
dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
dim = ctx.qkv_layout.split('_')[1].find('2')
dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim)
dkv_c_fp8 = dkv_fp8.view(-1,
dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
dkv = cast_from_fp8(dkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1])
dk, dv = [x.squeeze(dim) for x in [dk, dv]]
if qkv_group == 3:
dq = cast_from_fp8(
dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
dk = cast_from_fp8(
dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape)
dv = cast_from_fp8(
dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape)
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 backward')
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, out, d_out,
ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
......@@ -2074,7 +2654,7 @@ class FusedAttnFunc(torch.autograd.Function):
None, None, None, None, None, None)
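The combine/cast/split pattern used throughout `FusedAttnFunc` relies on the TE-internal helpers `_combine_tensors` and `_SplitAlongDim`. As a hedged stand-in (assuming only torch), the round trip behaves like stacking q/k/v on a new axis and splitting it back:

```python
# Hedged sketch: torch.stack/torch.split shown as stand-ins for the
# TE-internal _combine_tensors/_SplitAlongDim helpers used above.
import torch

q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
dim = 2                                   # packing axis, e.g. qkv_layout.find('3')
qkv = torch.stack([q, k, v], dim=dim)     # combine -> shape (2, 4, 3, 8)
q2, k2, v2 = (x.squeeze(dim) for x in torch.split(qkv, 1, dim=dim))
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
```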
class FusedAttention(TransformerEngineBaseModule):
"""Dot product attention, with multiple backends:
1. FusedAttnBackend["F16_max512_seqlen"]
......@@ -2110,6 +2690,8 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
tp_size: int = 1,
tp_group: Optional[dist_group_type] = None,
) -> None:
super().__init__()
......@@ -2136,6 +2718,15 @@ class FusedAttention(torch.nn.Module):
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
self.tp_size = tp_size
self.tp_group = tp_group
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""Needs override."""
@no_torch_dynamo()
def forward(
self,
......@@ -2157,6 +2748,7 @@ class FusedAttention(torch.nn.Module):
cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
is_first_microbatch: Optional[bool] = None,
) -> torch.Tensor:
"""fused attention fprop"""
......@@ -2164,9 +2756,9 @@ class FusedAttention(torch.nn.Module):
!= tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), 'No fused attention backend supports this input combination!'
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
), 'FusedAttention only supports FP16, BF16 and FP8 (uint8) data types.'
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
......@@ -2248,24 +2840,43 @@ class FusedAttention(torch.nn.Module):
if qkv_format == 'sbhd':
output = output.transpose(0,1).contiguous()
else:
with self.prepare_forward(query_layer,
is_first_microbatch,
num_gemms=3,
allow_non_contiguous=True) as query_layer:
with self.attention_dropout_ctx():
forced_fp8_dpa = ""
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
forced_fp8_dpa = " (forced)"
if _NVTE_DEBUG:
print("[DotProductAttention]: "
f"""using fp8_recipe.fp8_mha={self.fp8_meta["recipe"].fp8_mha}, """
f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}"""
f"""{forced_fp8_dpa} and """
f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""")
output = FusedAttnFunc.apply(
self.training,
max_seqlen_q, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_kv,
query_layer, key_layer, value_layer,
qkv_dtype,
core_attention_bias,
1.0/self.norm_factor,
self.attention_dropout if self.training else 0.0,
fast_zero_fill,
qkv_layout,
core_attention_bias_type,
attn_mask_type,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
self.fp8_meta,
self.tp_size,
self.tp_group,
)
# ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1)
......@@ -2463,7 +3074,9 @@ class DotProductAttention(torch.nn.Module):
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
tp_size=self.tp_size,
tp_group=self.tp_group)
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
......@@ -2532,6 +3145,7 @@ class DotProductAttention(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
is_first_microbatch: Optional[bool] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
......@@ -2635,6 +3249,19 @@ class DotProductAttention(torch.nn.Module):
Adjustments of the sequence_len_offset should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
assert (
......@@ -2746,8 +3373,14 @@ class DotProductAttention(torch.nn.Module):
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimension in 'key_layer' and 'value_layer'!"""
if (isinstance(query_layer, Float8Tensor)
and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor)):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format = qkv_format)
else:
qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format = qkv_format)
# The priority for attention backends (subject to availability and clearing the filters)
# is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention.
......@@ -2767,8 +3400,13 @@ class DotProductAttention(torch.nn.Module):
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
):
use_flash_attention = False
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
use_fused_attention = False
# Filter: Device and dimensions.
......@@ -2865,8 +3503,10 @@ class DotProductAttention(torch.nn.Module):
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype]
if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype,
TE_DType[key_layer.dtype]
if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
......@@ -2879,7 +3519,9 @@ class DotProductAttention(torch.nn.Module):
)
# FP8 DPA is supported via FusedAttnBackend["FP8"]
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"],
FusedAttnBackend["F16_arbitrary_seqlen"],
FusedAttnBackend["FP8"]])
use_fused_attention = ( \
use_fused_attention and is_backend_avail and \
(not context_parallel or \
......@@ -2950,6 +3592,8 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
......@@ -2959,8 +3603,7 @@ class DotProductAttention(torch.nn.Module):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
is_first_microbatch=is_first_microbatch)
return self.fused_attention(
query_layer,
key_layer,
......@@ -2968,6 +3611,8 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
......@@ -2977,8 +3622,7 @@ class DotProductAttention(torch.nn.Module):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
is_first_microbatch=is_first_microbatch)
assert (not context_parallel), \
"Context parallelism is only implemented with Flash Attention and Fused Attention!"
......@@ -3552,6 +4196,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
)
num_queries_per_key_value = (self.num_attention_heads_per_partition //
......@@ -3603,6 +4248,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
)
if self.qkv_weight_interleaved:
......@@ -3633,6 +4279,9 @@ class MultiheadAttention(torch.nn.Module):
key_layer, value_layer = torch.split(
mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
)
key_layer, value_layer = (x.reshape(
x.size(0), x.size(1), -1, self.hidden_size_per_attention_head,
) for x in (key_layer, value_layer))
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
......@@ -3648,6 +4297,7 @@ class MultiheadAttention(torch.nn.Module):
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
)
# [sq, b, hp] --> [sq, b, np, hn]
......@@ -3662,6 +4312,9 @@ class MultiheadAttention(torch.nn.Module):
# ======================================================
if rotary_pos_emb is not None:
assert (not isinstance(query_layer, Float8Tensor)
and not isinstance(key_layer, Float8Tensor)
), "RoPE is not supported for Float8Tensors!"
# duplicate the pos_emb for self attention
if not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = ((rotary_pos_emb,) * 2)
......
......@@ -84,6 +84,7 @@ def fused_attn_fwd_qkvpacked(
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
......@@ -119,6 +120,8 @@ def fused_attn_fwd_qkvpacked(
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
......@@ -206,6 +209,8 @@ def fused_attn_fwd_qkvpacked(
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (d_scale_s is not None
), "d_scale_s is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
......@@ -220,7 +225,7 @@ def fused_attn_fwd_qkvpacked(
max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, qkv_dtype,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
rng_gen, rng_elts_per_thread,
)
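The `d_scale_*`/`q_scale_*` pairs documented above follow a single convention: a tensor is quantized as `t_q = t * q_scale` and dequantized as `t_q * d_scale`, with `d_scale = 1 / q_scale` (stored by TE as `scale_inv`). A numeric sketch, assuming E4M3's 448.0 representable maximum (actual recipe values may differ):

```python
# Hedged sketch of FP8 scale bookkeeping: quantize with q_scale, clamp to
# the representable range, dequantize with d_scale = 1 / q_scale.
amax = 3.0                       # running max |t|, e.g. from amax_history
fp8_max = 448.0                  # largest representable E4M3 value (assumed)
q_scale = fp8_max / amax         # quantization scale
d_scale = 1.0 / q_scale          # what the kernels receive as d_scale_*
t = [1.5, -3.0, 0.25]
t_q = [max(-fp8_max, min(x * q_scale, fp8_max)) for x in t]   # pre-rounding
t_back = [x * d_scale for x in t_q]
assert all(abs(a - b) < 1e-9 for a, b in zip(t, t_back))
```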
......@@ -235,12 +240,14 @@ def fused_attn_bwd_qkvpacked(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
d_scale_dp: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
......@@ -272,6 +279,8 @@ def fused_attn_bwd_qkvpacked(
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
dqkv_dtype: tex.DType
data type of dQKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
......@@ -285,6 +294,8 @@ def fused_attn_bwd_qkvpacked(
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
d_scale_dp: torch.Tensor, default = None
input tensor for the dequantization of dP in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
......@@ -336,6 +347,7 @@ def fused_attn_bwd_qkvpacked(
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
......@@ -348,8 +360,8 @@ def fused_attn_bwd_qkvpacked(
output_tensors = tex.fused_attn_bwd_qkvpacked(
max_seqlen, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
......@@ -368,6 +380,7 @@ def fused_attn_fwd_kvpacked(
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
......@@ -410,6 +423,8 @@ def fused_attn_fwd_kvpacked(
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
......@@ -496,12 +511,25 @@ def fused_attn_fwd_kvpacked(
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (d_scale_s is not None
), "d_scale_s is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
), "q_scale_o is required as an input for FP8 fused attention."
assert (amax_s is not None
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
......@@ -519,12 +547,14 @@ def fused_attn_bwd_kvpacked(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
d_scale_dp: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
......@@ -562,7 +592,9 @@ def fused_attn_bwd_kvpacked(
input tensor dO (gradient of O);
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of Q and KV; in tex.DType, not torch.dtype
dqkv_dtype: tex.DType
data type of dQ and dKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
......@@ -576,6 +608,8 @@ def fused_attn_bwd_kvpacked(
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
d_scale_dp: torch.Tensor, default = None
input tensor for the dequantization of dP in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
......@@ -631,6 +665,7 @@ def fused_attn_bwd_kvpacked(
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
......@@ -643,8 +678,8 @@ def fused_attn_bwd_kvpacked(
output_tensors = tex.fused_attn_bwd_kvpacked(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
......@@ -664,6 +699,7 @@ def fused_attn_fwd(
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
......@@ -710,6 +746,8 @@ def fused_attn_fwd(
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
......@@ -798,12 +836,25 @@ def fused_attn_fwd(
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (d_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
), "q_scale_o is required as an input for FP8 fused attention."
assert (amax_s is not None
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_fwd(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
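The new `d_scale_s` argument completes the quantize/dequantize pair for S = Softmax(Q * K.T): `q_scale_s` maps S into the FP8 range, `d_scale_s` maps it back, and `amax_s` records the largest magnitude for deriving the next scale. A rough illustration of that bookkeeping in plain Python (not the Transformer Engine API; the E4M3 max-normal value and the scale-update rule are assumptions for the sketch):

```python
# Sketch of FP8-style quantize/dequantize bookkeeping for a tensor S.
FP8_E4M3_MAX = 448.0  # assumed max-normal magnitude for E4M3

def quantize(values, q_scale):
    """Scale and clamp values as an FP8 cast would (rounding omitted)."""
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * q_scale)) for v in values]

def dequantize(values_fp8, d_scale):
    """Undo the quantization scaling."""
    return [v * d_scale for v in values_fp8]

def update_scale(amax, margin=0.0):
    """Derive a q_scale so that amax maps just inside the FP8 range."""
    return (FP8_E4M3_MAX / amax) * (2.0 ** -margin) if amax > 0 else 1.0

s = [0.5, -2.0, 3.5]
amax_s = max(abs(v) for v in s)       # tracked by the kernel as amax_S
q_scale_s = update_scale(amax_s)      # used when writing S in FP8
d_scale_s = 1.0 / q_scale_s           # used when reading S back
s_fp8 = quantize(s, q_scale_s)
s_back = dequantize(s_fp8, d_scale_s)  # round-trips exactly absent rounding
```

This is why the backward pass now needs `d_scale_s` (and analogously `d_scale_dp`) as explicit inputs: the dequantization factor belongs to the producer of the FP8 tensor, not the consumer.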
......@@ -822,12 +873,14 @@ def fused_attn_bwd(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
d_scale_dp: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
......@@ -869,6 +922,8 @@ def fused_attn_bwd(
same shape as Q
qkv_dtype: tex.DType
data type of Q, K and V; in tex.DType, not torch.dtype
dqkv_dtype: tex.DType
data type of dQ, dK and dV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
......@@ -882,6 +937,8 @@ def fused_attn_bwd(
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
d_scale_dp: torch.Tensor, default = None
input tensor for the dequantization of dP in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
......@@ -941,6 +998,7 @@ def fused_attn_bwd(
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
......@@ -953,8 +1011,8 @@ def fused_attn_bwd(
output_tensors = tex.fused_attn_bwd(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
......
......@@ -786,9 +786,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (do_gelu
? (n_chunk * m) * D.element_size()
: (n_chunk * m) * HALF_BYTES);
const int output_chunk_bytes = (n_chunk * m) * D.element_size();
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
// Get output and workspace data pointers
......
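The `output_chunk_bytes` change above sizes the chunk from `D.element_size()` in all cases, instead of hard-coding half-precision bytes (`HALF_BYTES`) for the non-GELU path; with FP8 outputs the old assumption doubled the size. The arithmetic, in short (plain Python; the chunk dimensions are illustrative):

```python
def output_chunk_bytes(n_chunk, m, element_size):
    """Bytes in one GEMM output chunk: n_chunk * m elements of the
    output tensor D, whatever its element size (1 for FP8, 2 for FP16)."""
    return n_chunk * m * element_size

fp8_bytes = output_chunk_bytes(128, 1024, 1)   # FP8 output element
fp16_bytes = output_chunk_bytes(128, 1024, 2)  # FP16 output element
```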
......@@ -32,6 +32,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
......@@ -51,11 +52,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
......@@ -74,6 +77,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
......@@ -95,11 +99,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
......@@ -119,6 +125,7 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
......@@ -141,11 +148,13 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
......
......@@ -97,6 +97,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
......@@ -126,22 +127,24 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// FP8
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O ";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
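The `mha_fill` fast path above is now gated on the THD format, where `cu_seqlens` (cumulative sequence lengths) marks the end of the valid tokens, so only the tail of the packed buffer needs zeroing; other layouts fall back to a full `fill_(0)`. A minimal sketch of that idea (plain Python; `mha_fill` itself is a CUDA kernel, and the flat row layout here is an assumption):

```python
def tail_zero_fill(buffer, cu_seqlens):
    """Zero only the rows past the last valid token, as recorded by the
    final entry of cu_seqlens (THD layout)."""
    last_valid = cu_seqlens[-1]
    for i in range(last_valid, len(buffer)):
        buffer[i] = 0.0
    return buffer

# Three sequences of lengths 2, 1 and 1 packed into a buffer of 6 rows:
out = tail_zero_fill([9.0] * 6, cu_seqlens=[0, 2, 3, 4])
# rows 0..3 untouched, rows 4..5 zeroed
```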
......@@ -261,11 +264,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
......@@ -284,26 +289,29 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
auto h = q_shape[q_shape.size() - 2];
// create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
at::Tensor dQKV = torch::empty_like(QKV, options);
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
dQKV.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!descale_dP.has_value()) || (!scale_S.has_value())
|| (!scale_dP.has_value()) || (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, ";
err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, ");
err_tensors = err_tensors + std::string("amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
......@@ -311,14 +319,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, qkv_type,
DType::kFloat32, amax_dP.value().data_ptr(),
scale_dP.value().data_ptr(), descale_dP.value().data_ptr());
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
......@@ -327,13 +334,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
......@@ -433,6 +440,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
......@@ -458,24 +466,26 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// FP8
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O ";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
......@@ -608,11 +618,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
......@@ -635,15 +647,18 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
auto d = q_shape[q_shape.size() - 1];
// create output tensors dQ and dKV
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
at::Tensor dQ = torch::empty_like(Q, options);
at::Tensor dKV = torch::empty_like(KV, options);
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if (set_zero && ((h_q * d)% block_size == 0) && ((h_kv * d)% block_size == 0)) {
if (set_zero
&& ((h_q * d) % block_size == 0)
&& ((h_kv * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
......@@ -651,12 +666,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
dKV.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!descale_dP.has_value()) || (!scale_S.has_value())
|| (!scale_dP.has_value()) || (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, ";
err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, ");
err_tensors = err_tensors + std::string("amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
......@@ -666,16 +682,15 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type,
descale_dP.value().data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, qkv_type,
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
......@@ -686,15 +701,15 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
......@@ -806,6 +821,7 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
......@@ -832,14 +848,17 @@ std::vector<at::Tensor> fused_attn_fwd(
// FP8
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) {
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O ";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
......@@ -848,10 +867,9 @@ std::vector<at::Tensor> fused_attn_fwd(
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr());
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
......@@ -990,11 +1008,13 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
......@@ -1011,7 +1031,7 @@ std::vector<at::Tensor> fused_attn_bwd(
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
at::Tensor dQ;
at::Tensor dK;
......@@ -1046,7 +1066,7 @@ std::vector<at::Tensor> fused_attn_bwd(
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
dQ = torch::empty_like(Q);
dQ = torch::empty_like(Q, options);
tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2));
dKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
......@@ -1058,7 +1078,7 @@ std::vector<at::Tensor> fused_attn_bwd(
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
break;
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
dQ = torch::empty_like(Q);
dQ = torch::empty_like(Q, options);
tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2));
dKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
......@@ -1068,9 +1088,9 @@ std::vector<at::Tensor> fused_attn_bwd(
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
break;
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
dQ = torch::empty_like(Q);
dK = torch::empty_like(K);
dV = torch::empty_like(V);
dQ = torch::empty_like(Q, options);
dK = torch::empty_like(K, options);
dV = torch::empty_like(V, options);
break;
default:
NVTE_ERROR("QKV layout not supported!");
......@@ -1085,7 +1105,8 @@ std::vector<at::Tensor> fused_attn_bwd(
&& ((h_kv * d) % block_size == 0)
&& dQ.is_contiguous()
&& dK.is_contiguous()
&& dV.is_contiguous()) {
&& dV.is_contiguous()
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
......@@ -1095,12 +1116,13 @@ std::vector<at::Tensor> fused_attn_bwd(
dV.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value())
|| (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV");
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!descale_dP.has_value()) || (!scale_S.has_value())
|| (!scale_dP.has_value()) || (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, ";
err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, ");
err_tensors = err_tensors + std::string("amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
......@@ -1112,18 +1134,17 @@ std::vector<at::Tensor> fused_attn_bwd(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type,
descale_dP.value().data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, qkv_type,
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, qkv_type,
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
......@@ -1136,17 +1157,17 @@ std::vector<at::Tensor> fused_attn_bwd(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, nullptr);
dqkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
......
......@@ -4,7 +4,7 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
......@@ -233,6 +233,87 @@ class _IdentityFunc(torch.autograd.Function):
def backward(ctx, grad):
return grad.to(ctx.input_dtype), None
class _ViewFunc(torch.autograd.Function):
"""View function
View the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Optional[Tuple[int]] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.view(*shape),
)
return tensor.view(*shape)
@staticmethod
def backward(ctx,
grad: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.view(ctx.shape),
)
return dgrad, None
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Optional[Tuple[int]] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.reshape(*shape),
)
return tensor.reshape(*shape)
@staticmethod
def backward(ctx,
grad: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.reshape(ctx.shape),
)
return dgrad, None
return grad.reshape(ctx.shape), None
class Float8Tensor(torch.Tensor):
"""Experimental tensor class with FP8 data
......@@ -453,6 +534,12 @@ class Float8Tensor(torch.Tensor):
def clone(self) -> Float8Tensor:
return _IdentityFunc.apply(self, {"data": self._data.detach().clone()})
def view(self, *shape: Tuple[int]) -> Float8Tensor:
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8Tensor:
return _ReshapeFunc.apply(self, shape)
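`_ViewFunc` and `_ReshapeFunc` follow the standard custom-autograd pattern: forward saves the original shape on `ctx` and reshapes the raw `_data` (so no dequantization happens), backward views the incoming gradient back to the saved shape. The core of that pattern, stripped to plain Python with nested lists standing in for tensors and a dict standing in for the `ctx` object (a sketch, not the Float8Tensor API):

```python
def flatten(nested):
    """Flatten nested lists into a flat list of scalars."""
    if not isinstance(nested, list):
        return [nested]
    return [x for item in nested for x in flatten(item)]

def reshape(flat, shape):
    """Reshape a flat list into nested lists of the given shape."""
    if len(shape) == 1:
        assert len(flat) == shape[0]
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]

class ViewFunc:
    """Sketch of the _ViewFunc pattern: forward records the input shape,
    backward views the gradient back to it."""
    @staticmethod
    def forward(ctx, tensor, shape):
        ctx["shape"] = (len(tensor), len(tensor[0]))  # original 2-D shape
        return reshape(flatten(tensor), shape)

    @staticmethod
    def backward(ctx, grad):
        # Gradient of a view: same values, viewed back to the saved shape.
        return reshape(flatten(grad), ctx["shape"])

ctx = {}
x = [[1, 2, 3], [4, 5, 6]]            # shape (2, 3)
y = ViewFunc.forward(ctx, x, (3, 2))  # viewed as (3, 2)
g = ViewFunc.backward(ctx, y)         # gradient viewed back to (2, 3)
```

The real implementation additionally preserves the FP8 metadata via `Float8Tensor.make_like`, so the viewed tensor shares the original's scale and amax bookkeeping.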
def expand_as(self, other: torch.Tensor):
if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes
......
......@@ -202,6 +202,11 @@ class FP8GlobalStateManager:
# `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
# in an autocasted region and cross reference them in `float8_tensor.py`
# to perform the forward amax reduction.
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if fp8_meta_tensor_key not in fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
if forward and fp8_weights is not None:
autocast_key = cls.get_unique_autocast_key(
fp8_meta["recipe"], fp8_meta["fp8_group"])
......@@ -217,7 +222,6 @@ class FP8GlobalStateManager:
key = cls.get_key_in_buffer(
forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"])
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
......
......@@ -268,6 +268,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
for meta_key in fp8_meta_tensor_keys:
if meta_key not in self.fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
if length == curr_len:
continue
......@@ -568,6 +571,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
inp: torch.Tensor,
is_first_microbatch: Union[bool, None],
num_gemms: int = 1,
allow_non_contiguous: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
......@@ -610,7 +614,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
if not allow_non_contiguous:
yield inp.contiguous()
else:
yield inp
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
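The new `allow_non_contiguous` flag makes the `.contiguous()` copy in `prepare_forward` optional, so an FP8 attention input is not force-copied. A rough NumPy stand-in for that guard (hypothetical name `prepare_forward_sketch`; the real context manager also handles FP8 metadata and NVTX ranges):

```python
import numpy as np
from contextlib import contextmanager

@contextmanager
def prepare_forward_sketch(inp, allow_non_contiguous=False):
    # Mirrors the new branch: only force a contiguous copy when the
    # caller has not opted out of it.
    if not allow_non_contiguous:
        yield np.ascontiguousarray(inp)
    else:
        yield inp

x = np.arange(12).reshape(3, 4).T  # a transposed (non-contiguous) view
with prepare_forward_sketch(x) as y:
    assert y.flags["C_CONTIGUOUS"]  # copied into contiguous storage
with prepare_forward_sketch(x, allow_non_contiguous=True) as y:
    assert y is x                   # passed through untouched
```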
......@@ -645,8 +652,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
R4: bias gradient on R1.
"""
grad_output = grad_output.contiguous()
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
if isinstance(grad_output, Float8Tensor):
grad_output._data = grad_output._data.contiguous()
else:
grad_output = grad_output.contiguous()
grad_output_mat = grad_output.view(-1, grad_output.shape[-1])
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# No-FP8 case: bgrad is fused with wgrad for this case.
......@@ -684,16 +694,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
out=grad_output_c,
)
if not isinstance(grad_output_mat, Float8Tensor):
cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
out=grad_output_c,
)
else:
grad_output_c = grad_output_mat
if not ctx.ub_overlap_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
if not isinstance(grad_output_c, Float8Tensor):
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
grad_output_t = grad_output_c.transpose_2d()
else:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
grad_output_t = None
......@@ -702,28 +718,38 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# FP8 case without gather: cast, transpose, bgrad fused
if ctx.use_bias:
grad_output_mat_no_fp8 = grad_output_mat
if isinstance(grad_output_mat, Float8Tensor):
grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype)
grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
grad_output_mat,
grad_output_mat_no_fp8,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
grad_output_c, grad_output_t = fp8_cast_transpose_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
if isinstance(grad_output_mat, Float8Tensor):
grad_output_c = grad_output_mat
grad_output_t = grad_output_c.transpose_2d()
else:
grad_output_c, grad_output_t = fp8_cast_transpose_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
grad_output_t = None
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
if not isinstance(grad_output_mat, Float8Tensor):
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
grad_output_c = grad_output_mat
grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
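Taken together, the changes above route an incoming `Float8Tensor` gradient around the cast and fused cast-transpose kernels. A toy model of that control flow (strings stand in for the real kernels and tensors; the helper name is illustrative):

```python
def prep_grad_output_sketch(grad, is_float8, use_bias):
    """Toy dispatch mirroring grad_output_preprocess: a gradient that is
    already FP8 (the fp8_mha path) is reused directly and transposed with
    transpose_2d; otherwise it is cast with cast_to_fp8 first."""
    if is_float8:
        grad_c = grad                      # reuse FP8 data, no cast
        grad_t = f"transpose_2d({grad})"
    else:
        grad_c = f"cast_to_fp8({grad})"
        grad_t = f"fp8_transpose({grad_c})"
    grad_bias = f"bgrad({grad})" if use_bias else None
    return grad_c, grad_t, grad_bias

assert prep_grad_output_sketch("g", True, False) == ("g", "transpose_2d(g)", None)
assert prep_grad_output_sketch("g", False, True)[0] == "cast_to_fp8(g)"
```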
......
......@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
__all__ = ["LayerNormLinear"]
......@@ -190,6 +191,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out = ln_out_total
if fp8:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using FP8 forward')
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
......@@ -230,6 +234,15 @@ class _LayerNormLinear(torch.autograd.Function):
)
weight_t_fp8 = None
if fp8_meta["recipe"].fp8_mha:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
fp8_dtype_forward,
torch.uint8)
else:
out_index, meta_tensor, output_te_dtype, output_dtype = (
None, None, None, activation_dtype)
out, _ = tex.fp8_gemm(
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -239,7 +252,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
output_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
......@@ -247,8 +260,22 @@ class _LayerNormLinear(torch.autograd.Function):
ub_algo=ub_algo if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
out_index=out_index,
fp8_meta_tensor=meta_tensor,
D_dtype=output_te_dtype,
)
if output_dtype == torch.uint8:
out = Float8Tensor(data=out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_dtype=fp8_dtype_forward,
dtype=activation_dtype,
)
else:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using non-FP8 forward')
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
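The `fp8_mha` branch above picks the GEMM output parameters as a four-tuple, and only when the output dtype is uint8 is the result wrapped into a `Float8Tensor`. A toy version of that selection (strings stand in for the real enum and dtype objects):

```python
def select_gemm_output_params(fp8_mha, activation_dtype="activation_dtype"):
    """Mirror of the new out_index/meta_tensor/dtype selection: with
    recipe.fp8_mha the GEMM writes FP8 data (uint8 storage) destined for
    a Float8Tensor wrapper; otherwise the output stays in the activation
    dtype and no FP8 output metadata is needed."""
    if fp8_mha:
        return ("GEMM1_OUTPUT", "scaling_fwd", "fp8_dtype_forward", "uint8")
    return (None, None, None, activation_dtype)

assert select_gemm_output_params(True)[-1] == "uint8"
assert select_gemm_output_params(False) == (None, None, None, "activation_dtype")
```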
......@@ -338,7 +365,6 @@ class _LayerNormLinear(torch.autograd.Function):
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
......@@ -352,6 +378,10 @@ class _LayerNormLinear(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[0]._scale_inv
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
(
inputmat,
......@@ -465,6 +495,9 @@ class _LayerNormLinear(torch.autograd.Function):
ub_obj = None
if ctx.fp8:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
......@@ -486,7 +519,8 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
grad_output_c._data
if isinstance(grad_output_c, Float8Tensor) else grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
......@@ -503,6 +537,9 @@ class _LayerNormLinear(torch.autograd.Function):
)
clear_tensor_data(grad_output_c)
else:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using non-FP8 backward')
# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm(
weight,
......@@ -551,7 +588,8 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
grad_output_t._data
if isinstance(grad_output_t, Float8Tensor) else grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
......