Unverified commit 83a4c219, authored by cyanguwa and committed by GitHub

[C/PyTorch] Add FP8 DPA and MHA (#768)



* WIP: fp8 v1 fprop integration
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fprop working for h1; w/ debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup; bprop running but has mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gitlab frontend as submodule
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix after merge; add bias_b/h to caching descriptor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* distinguish fwd/bwd tensor types for bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for F16 cases; include added dqkv_type and d_scale_dp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adjust out shape for bwd in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add casting from/to FP8 to DPA module
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: bshd_bshd_bshd layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: support all sbhd/bshd layouts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add qkvpacked and kvpacked support in both FusedAttnFunc and C levels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove qkvpacked/kvpacked calls in DPA module (used for testing)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MQA/GQA in FP8 v1 API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 705d8e3, with API change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test causal mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict mha_fill for THD format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fused attn with CP and comment out is_alibi code
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up FE0.9 vs FE1.0 FP8 implementations, and related unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change NVTE_FP8_DPA_BWD default to 1, and fix its use in qkvpacked/kvpacked APIs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and self.tp_size/group in FusedAttention()
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 6902c94
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 MHA support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to FE v1.3.0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for FP8 MHA with different configs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* emit stats regardless of is_training
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix linear when input is not Float8Tensor
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix d_out type when f16 bprop
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix user buffer for layernorm_linear/linear and revert two FP8 casts in MHA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for fp8_dpa/mha in recipe
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix backend selection to avoid FA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use RMSE for FP8 unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace two more transpose with transpose_2d
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add FP8 initialization to FusedAttention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rm docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Revert "add FP8 initialization to FusedAttention"

This reverts commit 15fffd825d6f23f31ea709b16ba01dfd61efabf8.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change order of ctxs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back docs and mark as beta
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for tests and docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
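
The fp8_dpa/fp8_mha recipe flags and the NVTE_FP8_DPA_BWD toggle introduced by this change are exercised in the updated unit tests further down. As a rough usage sketch only (assuming a Hopper GPU and mirroring the defaults used in those tests; this snippet is not part of the diff itself):

import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Steer attention to the cuDNN fused backend, as the FP8 tests do.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
# NVTE_FP8_DPA_BWD=1 (the default) also runs the attention backward pass in FP8;
# setting it to "0" keeps the FP8 forward but falls back to an F16 backward.
os.environ["NVTE_FP8_DPA_BWD"] = "1"

# fp8_dpa/fp8_mha are the new (beta) DelayedScaling knobs added by this PR.
fp8_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.HYBRID,
    amax_history_len=1,
    amax_compute_algo="most_recent",
    fp8_dpa=True,
    fp8_mha=True,
)

# Shapes mirror model_configs_fp8_vs_f16["fp8_9"]: b=2, h=24, d=128, s=2048.
with te.fp8_model_init(enabled=True):
    mha = te.MultiheadAttention(
        hidden_size=24 * 128,
        num_attention_heads=24,
        params_dtype=torch.float16,
        fuse_qkv_params=True,
    ).to(dtype=torch.float16, device="cuda")

# Default MultiheadAttention input layout is sbhd: [seq, batch, hidden].
x = torch.randn(2048, 2, 24 * 128, dtype=torch.float16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = mha(x)
out.backward(torch.ones_like(out))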
parent f69e45be

-Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
+Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348
@@ -6,7 +6,7 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
-pip install pytest==6.2.5 onnxruntime==1.13.1
+pip install pytest==7.2 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
...
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
+import math
import functools
from importlib.metadata import version
import os
@@ -12,9 +13,10 @@ import pytest
import torch

from transformer_engine.common import recipe
-from transformer_engine.pytorch import TransformerLayer, fp8_autocast
+from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import (
    DotProductAttention,
+   MultiheadAttention,
    RotaryPositionEmbedding,
)
from transformer_engine.pytorch.constants import TE_DType
@@ -939,52 +941,415 @@ def _run_transformer_layer(
    return out, inp.grad
-model_configs_fp8 = {
+model_configs_fp8_vs_f16 = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-   "fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
-   "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
+   "fp8_9 ": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+   "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+   "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+   "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+   "fp8_13": ModelConfig(1, 32,  4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
+   "fp8_14": ModelConfig(1, 32,  4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
}
-param_types_fp8 = [torch.float16]
+param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
+qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd']
+qkv_format_fp8_vs_f16 = ['bshd', 'sbhd']
+
+def _rmse(a, b):
+    return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum())
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
-@pytest.mark.parametrize("dtype", param_types_fp8)
-@pytest.mark.parametrize("model", model_configs_fp8.keys())
-def test_dpa_fp8(dtype, model):
-    """Test FP8 dot product attention
-    FusedAttention uses fused_attn_fwd/bwd_qkvpacked from cpp_extensions,
-    and UnfusedDotProductAttention uses plain PyTorch operations in FP16
-    and converts inputs/outputs from/to FP8.
+@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
+@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
+@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
+@pytest.mark.parametrize("input_layernorm", [True, False])
+@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    config = model_configs_fp8_vs_f16[model]
+    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
+    if _NVTE_DEBUG:
+        print()
print("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm)
if _NVTE_DEBUG:
print()
print("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm)
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
fused_attn_fwd_f16.min().item())
if _NVTE_DEBUG:
print()
print('========== {:^25s} =========='.format('forward output'))
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
print('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
print(e)
print()
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
fused_attn_bwd_f16[i].min().item())
if _NVTE_DEBUG:
print()
print('========== {:^25s} =========='.format(param_names[i]))
print('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
print('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
print('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
print(e)
print()
assert(bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
""" fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_mha,
fp8_mha=fp8_mha,
)
-    config = model_configs_fp8[model]
+    with fp8_model_init(enabled=fp8_mha):
mha = (MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
bias=True,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
input_layernorm=input_layernorm,
fuse_qkv_params=True,
attention_type="self",
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
)
-    # Skip if not supported
-    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
-        config, dtype)
-    if not fused_attn_supported:
-        pytest.skip("FusedAttention does not support this model config")
+    seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
+        dtype=torch.int32, device="cuda")
+    seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
+        dtype=torch.int32, device="cuda")
+    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
-    # Run dot-product attention with different backends
-    fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
-        dtype, config, "FusedAttention")
-    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
-        dtype, config, "UnfusedDotProductAttention")
+    dim_to_num = {
+        'b'  : config.batch_size,
+        'sq' : config.max_seqlen_q,
+        'skv': config.max_seqlen_kv,
+        'h'  : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
layout = '_'.join(qkv_format)
layout = layout.replace('s', 'sq')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
out = mha(hidden_states,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
)
out.backward(out_grad)
-    tols = dict(atol=2.5e-2, rtol=2.5e-2)
-    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
-    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
+    param_names = []
+    param_names.append('hidden_states.grad')
+    params = []
params.append(hidden_states)
for name, param in mha.named_parameters():
if param.requires_grad:
param_names.append(name+'.grad')
params.append(param)
return out, param_names, tuple(x.grad for x in params)
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
config = model_configs_fp8_vs_f16[model]
if (config.num_heads != config.num_gqa_groups and '3' in qkv_layout):
pytest.skip("qkv_layout not applicable for MQA/GQA");
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
if _NVTE_DEBUG:
print()
print("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout)
if _NVTE_DEBUG:
print("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout)
tols = dict(atol=5e-1, rtol=5e-2)
if _NVTE_DEBUG:
print('[test_dpa_fp8_vs_f16]: ', tols)
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
print('fused_attn_fwd RMSE: {:.6f}'.format(
_rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)))
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
for i,_ in enumerate(fused_attn_bwd_f16):
if _NVTE_DEBUG:
print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
print('fused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
print('fused_attn_bwd RMSE: {:.6f}'.format(
_rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])))
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
fp8_recipe = recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_dpa,
)
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
dpa = (
DotProductAttention(
config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
)
seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
dtype=torch.int32, device="cuda")
seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
'b' : config.batch_size,
'sq' : config.max_seqlen_q,
'skv': config.max_seqlen_kv,
'h' : config.num_heads,
'hg' : config.num_gqa_groups,
'd' : config.head_dim,
't' : cu_seqlens_q[-1],
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
layout = '_'.join(layout)
if i == 0:
layout = layout.replace('s', 'sq')
else:
layout = layout.replace('s', 'skv')
layout = layout.replace('h', 'hg')
layout = layout.replace('t', 'tg')
tensor_shape = [dim_to_num[j] for j in layout.split('_')]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split('_')):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
else:
inp.append(tensors[j])
for i in range(3):
inp[i].requires_grad = True
qkv_format_kv = '_'.join(qkv_format)
qkv_format_kv = qkv_format_kv.replace('s', 'sq')
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.1 * torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
out = dpa(inp[0], inp[1], inp[2],
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
)
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
-def _run_dpa_fp8(dtype, config, backend):
-    """Run FusedAttention FP8 backend, i.e.
-    fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1'))
models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6']
models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8']
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
"""Test FP8 dot product attention implementations based on cuDNN frontend
v0.9 and v1.0+. Each test compares results from a custom implementation of
an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""
config = model_configs_fp8[model]
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(
dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(),
unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
unfused_attn_fwd_f16.min().item())
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
bwd_range = max(fused_attn_bwd_fp8.max().item(),
unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(),
unfused_attn_bwd_f16.min().item())
if _NVTE_DEBUG:
print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
print('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()))
print('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format(
fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e:
print(e)
print()
print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()))
print('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()))
print('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format(
bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e:
print(e)
print()
assert(fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
    assert(bwd_rmse < rmse_tol * bwd_range
        ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
        bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
def _run_custom_mha_fp8(dtype, config, backend):
"""Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -993,13 +1358,14 @@ def _run_dpa_fp8(dtype, config, backend):
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

-    inp = 0.01 * torch.randn(
-        config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
+    inp = 0.0001 * torch.randint(0, 100,
+        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
        dtype=dtype, device="cuda", requires_grad=True)
    seqlens = torch.full([config.batch_size], config.max_seqlen_q,
        dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
        dtype=dtype, device="cuda")
@@ -1013,22 +1379,21 @@ def _run_dpa_fp8(dtype, config, backend):
        amax_compute_algo="most_recent",
    )

-    dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda")
+    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-        out = dpa(inp, cu_seqlens, config.max_seqlen_q)
+        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

-    context = torch.load("ctx.pt")
+    out = torch.load("out.pt")
    dqkv = torch.load('dqkv.pt')
-    return (context.view(config.batch_size, config.max_seqlen_q, -1).transpose(0,1),
-        dqkv.view(config.batch_size, config.max_seqlen_q, 3,
-            config.num_heads, config.head_dim).transpose(0,1).contiguous())
+    return (out.view(config.batch_size, config.max_seqlen_q, -1),
+        dqkv.view(config.batch_size, config.max_seqlen_q, 3,
+            config.num_heads, config.head_dim).contiguous())
-def _run_dpa_fp8_ref(dtype, config, backend):
-    """Run UnfusedDotProductAttention as a reference, i.e.
-    plain PyTorch implementation in FP16 and inputs/outputs
-    are converted from/to FP8"""
+def _run_ref_mha_f16(dtype, config, backend):
+    """Run reference F16 FusedAttention. Both input and output
+    are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -1043,7 +1408,7 @@ def _run_dpa_fp8_ref(dtype, config, backend):
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = torch.load('out_grad.pt').to(device="cuda").view(
-        config.batch_size, config.max_seqlen_q, -1).transpose(0,1)
+        config.batch_size, config.max_seqlen_q, -1)

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1069,13 +1434,14 @@ def _run_dpa_fp8_ref(dtype, config, backend):
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
-            attention_type="self"
+            attention_type="self",
+            qkv_format="bshd",
        ).to(dtype=dtype, device="cuda")
    )

-    q = inp[:, :,0,:,:]
-    k = inp[:, :,1,:,:]
-    v = inp[:, :,2,:,:]
+    q = inp[:,:,0,:,:]
+    k = inp[:,:,1,:,:]
+    v = inp[:,:,2,:,:]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)
@@ -1088,14 +1454,14 @@ _2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
+META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
-META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
-META_S = tex.FP8FwdTensors.GEMM3_WEIGHT
-META_DS = tex.FP8BwdTensors.GRAD_INPUT3
+META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
+META_DP = tex.FP8BwdTensors.GRAD_INPUT3
-class _dpa_fp8(torch.autograd.Function):
+class _custom_mha_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
@@ -1110,6 +1476,7 @@ class _dpa_fp8(torch.autograd.Function):
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
+       mask_type: str,
    ) -> torch.Tensor:
        assert inp.dim() == 2
@@ -1117,14 +1484,10 @@ class _dpa_fp8(torch.autograd.Function):
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1
-       is_nl = False
-       if b < 4 and b > 1:
-           max_s = 512
-           is_nl = True

        fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

-       inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
+       inp_fp8, inp_t_fp8 = ext.fp8_cast_transpose_fused(
            inp,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_INPUT,
@@ -1142,12 +1505,12 @@ class _dpa_fp8(torch.autograd.Function):
        ZInv = None
        philox_unpacked = None

-       qkv_out, _ = ext.fp8_gemm(
+       qkv, _ = ext.fp8_gemm(
            qkv_weight_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
-           inputmat,
+           inp_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
@@ -1160,26 +1523,29 @@ class _dpa_fp8(torch.autograd.Function):
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )
-       qkv_out = qkv_out.view(-1, 3, h, d)
-       qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
+       qkv = qkv.view(-1, 3, h, d)
+       qkv_fp16 = ext.cast_from_fp8(qkv, fp8_meta["scaling_fwd"],
            META_QKV, fp8_dtype_forward,
-           tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
-       torch.save(qkv_out_fp16, 'qkv.pt')
+           tex.DType.kFloat16).view(b, max_s, 3, h, d).contiguous()
+       torch.save(qkv_fp16, 'qkv.pt')
+       if cudnn_frontend_version == 1:
+           qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
        # FMHA
-       context_, aux_ctx_tensors, *rest = fused_attn_fwd(
+       out, aux_ctx_tensors, *rest = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
-           qkv_out[:,0,:,:],
-           qkv_out[:,1,:,:],
-           qkv_out[:,2,:,:],
+           qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
+           qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
+           qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
            fp8_dtype_forward,
            FusedAttnBackend["FP8"],
            None,
            fp8_meta["scaling_fwd"].scale_inv[META_QKV],
+           fp8_meta["scaling_fwd"].scale_inv[META_S],
            fp8_meta["scaling_fwd"].scale[META_S],
            fp8_meta["scaling_fwd"].scale[META_O],
            fp8_meta["scaling_fwd"].amax_history[0][META_S],
@@ -1187,20 +1553,17 @@ class _dpa_fp8(torch.autograd.Function):
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
-           qkv_layout="t3hd",
+           qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
-           attn_mask_type="padding",
+           attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
        )
        M, ZInv, philox_unpacked = aux_ctx_tensors
-       context = context_.view(-1, in_features)
-       context_t = tex.fp8_transpose(context, fp8_dtype_forward)

        ctx.save_for_backward(
-           inputmat_t, qkv_weight_t_fp8, workspace,
-           qkv_out,
-           context_, context_t,
+           inp_t_fp8, qkv_weight_t_fp8, workspace,
+           qkv, out,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
@@ -1210,14 +1573,16 @@ class _dpa_fp8(torch.autograd.Function):
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
-       ctx.is_nl = is_nl
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
+       ctx.mask_type = mask_type
+       ctx.dtype = inp.dtype

-       context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
+       out = out.view(-1, in_features) # (bs)(hd)
+       out_fp16 = ext.cast_from_fp8(out, fp8_meta["scaling_fwd"],
            META_O, fp8_dtype_forward, tex.DType.kFloat16)
-       torch.save(context_fp16, 'ctx.pt')
-       return context_fp16
+       torch.save(out_fp16, 'out.pt') # (bs)(hd)
+       return out_fp16
    @staticmethod
@@ -1226,11 +1591,10 @@ class _dpa_fp8(torch.autograd.Function):
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            (
-               inputmat_t,
+               inp_t_fp8,
                qkv_weight_t_fp8,
                workspace,
-               qkv_out,
-               context, context_t,
+               qkv, out,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
@@ -1243,51 +1607,59 @@ class _dpa_fp8(torch.autograd.Function):
            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
-           )
+           ) # (bs)(hd)

            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
-               qkv_out[:,0,:,:],
-               qkv_out[:,1,:,:],
-               qkv_out[:,2,:,:],
-               context,
-               proj_dgrad.view_as(context),
+               qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
+               qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
+               qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
+               out,
+               proj_dgrad.view_as(out),
                fp8_dtype_forward,
+               fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                fwd_scale_inverses[META_QKV], # d_scale_qkv,
                fwd_scale_inverses[META_S], # d_scale_s,
                fwd_scale_inverses[META_O], # d_scale_o,
                ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
+               ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
                fwd_scales[META_S], # q_scale_s
-               ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
+               ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
                ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
-               ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
+               ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
                ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
-               None,
-               ctx.p_dropout,
-               ctx.fast_zero_fill,
-               "t3hd",
-               "no_bias",
-               "padding",
+               attn_scale=None,
+               dropout=ctx.p_dropout,
+               fast_zero_fill=ctx.fast_zero_fill,
+               qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
+               attn_bias_type="no_bias",
+               attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
-           dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)
-           dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
-           dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
+           dim = 2 if cudnn_frontend_version == 1 else 1
+           dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype)
+           dqkv_shape = list(dq.shape)
+           dqkv_shape.insert(dim, 3)
+           dqkv_stride = list(dq.stride())
+           dqkv_stride.insert(dim, int(dqkv_stride[-3]/3))
+           dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd
+           dqkv_c = dqkv.view(-1, 3*ctx.hidden_size)
+           dqkv_c_fp16 = ext.cast_from_fp8(dqkv_c,
                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                fp8_dtype_backward, tex.DType.kFloat16)
-           torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')
+           torch.save(dqkv_c_fp16, 'dqkv.pt')

-           qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
-               dqkv_grad_output_c,
+           qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused(
+               dqkv_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
-               torch.float16,
+               ctx.dtype,
            )

            # QKV DGRAD
@@ -1296,25 +1668,25 @@ class _dpa_fp8(torch.autograd.Function):
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
-               dqkv_grad_output_c,
+               dqkv_c,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
-               torch.float16,
+               ctx.dtype,
                workspace,
                use_split_accumulator=_2X_ACC_DGRAD,
            )
            # QKV WGRAD
            qkv_wgrad, _ = ext.fp8_gemm(
-               inputmat_t,
+               inp_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
-               dqkv_grad_output_t,
+               dqkv_t,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
-               torch.float16,
+               ctx.dtype,
                workspace,
                use_split_accumulator=_2X_ACC_WGRAD,
            )
@@ -1334,7 +1706,7 @@ class _dpa_fp8(torch.autograd.Function):
                None)

-class DPA_FP8(TransformerEngineBaseModule):
+class Custom_MHA_FP8(TransformerEngineBaseModule):
    def __init__(
        self,
        config,
@@ -1345,6 +1717,7 @@ class DPA_FP8(TransformerEngineBaseModule):
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.fast_zero_fill = True
+       self.mask_type = config.attn_mask_type
        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
@@ -1374,7 +1747,7 @@ class DPA_FP8(TransformerEngineBaseModule):
        cu_seqlens, max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, None, num_gemms=3) as inp:
-           out = _dpa_fp8.apply(
+           out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
@@ -1385,7 +1758,8 @@ class DPA_FP8(TransformerEngineBaseModule):
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
-               self.training)
+               self.training,
+               self.mask_type)
        return out

    def get_fp8_weights_scratchpad(
...
@@ -1091,7 +1091,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    # Check output.
-    atol = {torch.float32 : 2e-4,
+    atol = {torch.float32 : 2.5e-4,
            torch.half    : 2e-3,
            torch.bfloat16: 2e-2,
           }
...
@@ -85,15 +85,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
  auto cudnn_runtime_version = cudnnGetVersion();
-  if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
-          && (sm_arch_ >= 90)
-          && (max_seqlen_q == max_seqlen_kv)
-          && (max_seqlen_q <= 512)
-          && (head_dim == 64)
-          && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
-          && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
-          && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)) {
+  if (((q_dtype == NVTEDType::kNVTEFloat8E4M3)
+          || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
+          && (sm_arch_ >= 90)
+          && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
+          && (
+          ((cudnn_runtime_version >= 8900)
+          && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)
+          && (max_seqlen_q == max_seqlen_kv)
+          && (num_attn_heads == num_gqa_groups)
+          && (max_seqlen_q <= 512)
+          && (head_dim == 64)
+          && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))
+          || ((cudnn_runtime_version >= 90100)
+          && (max_seqlen_q % 128 == 0)
+          && (max_seqlen_kv % 128 == 0)
+          && (head_dim == 128)
+          && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
+          || (qkv_format == NVTE_QKV_Format::NVTE_SBHD))
+          && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
+          || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
    if (cudnn_runtime_version >= 8900) {
      backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
    } else {
@@ -269,7 +279,7 @@ void nvte_fused_attn_fwd_qkvpacked(
#if (CUDNN_VERSION >= 8900)
    fused_attn_fp8_fwd_qkvpacked(
                    b, h, max_seqlen, d,
-                   is_training, attn_scale, dropout, qkv_layout,
+                   is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
                    input_QKV, input_output_S, output_O,
                    Aux_CTX_Tensors,
                    input_cu_seqlens,
@@ -379,7 +389,7 @@ void nvte_fused_attn_bwd_qkvpacked(
    const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
    fused_attn_fp8_bwd_qkvpacked(
                    b, h, max_seqlen, d,
-                   attn_scale, dropout, qkv_layout,
+                   attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
                    input_QKV, input_O, input_dO,
                    input_M, input_ZInv,
                    input_S, input_output_dP,
@@ -476,7 +486,18 @@ void nvte_fused_attn_fwd_kvpacked(
      "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
-   NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
+#if (CUDNN_VERSION >= 8900)
+   fused_attn_fp8_fwd_kvpacked(
+                   b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                   is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+                   input_Q, input_KV, input_output_S, output_O,
+                   Aux_CTX_Tensors,
+                   input_cu_seqlens_q, input_cu_seqlens_kv,
+                   input_rng_state,
+                   wkspace, stream, handle);
+#else
+   NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
+#endif
  } else {
    NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
  }
@@ -580,7 +601,23 @@ void nvte_fused_attn_bwd_kvpacked(
    NVTE_ERROR(err_msg);
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
-   NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
+#if (CUDNN_VERSION >= 8900)
+   const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
+   const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
+   const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
+   fused_attn_fp8_bwd_kvpacked(
+                   b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                   attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
+                   input_Q, input_KV, input_O, input_dO,
+                   input_M, input_ZInv,
+                   input_S, input_output_dP,
+                   output_dQ, output_dKV,
+                   input_cu_seqlens_q, input_cu_seqlens_kv,
+                   input_rng_state,
+                   wkspace, stream, handle);
+#else
+   NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
+#endif
  } else {
    NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
  }
@@ -662,8 +699,8 @@ void nvte_fused_attn_fwd(
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
    fused_attn_fp8_fwd(
-                   b, h_q, max_seqlen_q, max_seqlen_kv, d,
-                   is_training, attn_scale, dropout, qkv_layout,
+                   b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                   is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
                    input_Q, input_K, input_V, input_output_S, output_O,
                    Aux_CTX_Tensors,
                    input_cu_seqlens_q, input_cu_seqlens_kv,
@@ -775,8 +812,8 @@ void nvte_fused_attn_bwd(
    const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
    const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
    fused_attn_fp8_bwd(
-                   b, h_q, max_seqlen_q, max_seqlen_kv, d,
-                   attn_scale, dropout, qkv_layout,
+                   b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                   attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
                    input_Q, input_K, input_V, input_O, input_dO,
                    input_M, input_ZInv,
                    input_S, input_output_dP,
...
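
The backend-selection change earlier in this file admits two FP8 fused-attention paths: the original cuDNN 8.9 path (t3hd layout, max_seqlen <= 512, head_dim 64, padding mask) and a new cuDNN 9.1+ path (bshd/sbhd layouts, sequence lengths divisible by 128, head_dim 128, causal or no mask), both restricted to Hopper (sm90) with no bias. A small sketch, for illustration only, of how callers such as the unit tests steer selection toward this backend and guard on FP8 support (check_fp8_support is the helper the tests rely on for the same purpose):

import os
import torch
from transformer_engine.pytorch.fp8 import check_fp8_support

# Disable flash-attention and enable cuDNN fused attention, as the FP8 tests do.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

fp8_available, reason_for_no_fp8 = check_fp8_support()
if not fp8_available or torch.cuda.get_device_capability() != (9, 0):
    raise RuntimeError(f"FP8 fused attention unavailable: {reason_for_no_fp8 or 'requires Hopper (sm90)'}")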
@@ -76,7 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                           scaling_factor, is_training,
                           dropout_probability, layout,
                           bias_type, mask_type,
-                          tensorType};
+                          tensorType, tensorType};
  namespace fe = cudnn_frontend;
  using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
@@ -147,7 +147,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
    fe::graph::SDPA_attributes sdpa_options;
    sdpa_options = fe::graph::SDPA_attributes()
                       .set_name("flash_attention")
-                      .set_is_inference(!is_training)
+                      .set_is_inference(false)
                       .set_causal_mask(is_causal)
                       .set_attn_scale(attn_scale);
@@ -199,11 +199,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                    layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
    O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
-   if (is_training) {
-     Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
-                     .set_dim({b, h, s_q, 1})
-                     .set_stride({h * s_q, s_q, 1, 1});
-   }
+   Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
+                   .set_dim({b, h, s_q, 1})
+                   .set_stride({h * s_q, s_q, 1, 1});

    std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
               std::shared_ptr<fe::graph::Tensor_attributes>, // K
@@ -211,7 +209,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
               std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
               std::shared_ptr<fe::graph::Tensor_attributes> > // O
        key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
-   auto Stats_tuple = is_training ? std::make_tuple(Stats) : std::make_tuple(nullptr);
+   auto Stats_tuple = std::make_tuple(Stats);
    auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
    auto padding_tuple = is_padding ?
        std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
@@ -258,11 +256,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
        {K, devPtrK},
        {V, devPtrV},
        {attn_scale, &scaling_factor},
-       {O, devPtrO}};
-   if (is_training) {
-     variant_pack[Stats] = devPtrSoftmaxStats;
-   }
+       {O, devPtrO},
+       {Stats, devPtrSoftmaxStats}};

    if (is_bias) {
      variant_pack[bias] = devPtrBias;
@@ -321,7 +316,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                           scaling_factor, true,
                           dropout_probability, layout,
                           bias_type, mask_type,
-                          tensorType};
+                          tensorType, tensorType};
  namespace fe = cudnn_frontend;
  using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
...
@@ -19,7 +19,7 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
                    size_t batch, size_t num_attn_heads, size_t max_seqlen,
-                   size_t head_size, bool is_training, float attn_scale,
+                   size_t head_dim, bool is_training, float attn_scale,
                    float p_dropout, NVTE_QKV_Layout qkv_layout,
                    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                    const Tensor *input_QKV, const Tensor *input_Bias,
...
@@ -8,6 +8,7 @@
#include "../common.h"
#include "utils.h"
+#include "../util/system.h"
#include "fused_attn_fp8.h"

namespace transformer_engine {
@@ -984,7 +985,7 @@ static cudnn_frontend::Tensor createdSQBMM(
  return After_dSTranspose_Q;
}

-// fused attention FWD FP8
+// fused attention FWD FP8 with FE 0.9
void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                  bool isTraining, float attnScale,
                  float dropoutProbability, NVTE_QKV_Layout layout,
@@ -1295,7 +1296,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
  }
}

-// fused attention BWD FP8
+// fused attention BWD FP8 with FE 0.9
void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                  float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
                  void* devPtrQ, void* devPtrK, void* devPtrV,
@@ -1846,6 +1847,707 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
  }
}
// fused attention FWD FP8 with FE 1.0+
void fused_attn_fp8_fwd_impl_v1(int64_t b, int64_t h, int64_t hg,
int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor,
float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
void* devPtrO,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
void* devPtrAmaxO, void* devPtrAmaxS,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnn_frontend::DataType_t fwd_tensor_type,
void* workspace,
size_t* workspace_size,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
auto bias_b = b;
auto bias_h = h;
  NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
  NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
  NVTE_CHECK(!is_padding,
             "FP8 fused attention does not support padding/padding_causal mask yet!");
  NVTE_CHECK(!is_dropout, "FP8 fused attention does not support dropout yet!");
try {
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
fwd_tensor_type, fwd_tensor_type};
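    // The descriptor doubles as the graph-cache key; the forward pass fills both
    // tensor-type slots with the same type, while the backward pass may carry a
    // different gradient type in the second slot.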
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_s
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_o
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_fp8_fprop_cache;
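    // Graphs are cached per thread and keyed on the descriptor, so repeated calls
    // with the same problem shape reuse the already-built execution plan.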
// Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto graph = it->second;
return graph;
}
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(fwd_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> descale_q, descale_k, descale_v;
std::shared_ptr<fe::graph::Tensor_attributes> descale_s, scale_s, scale_o;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
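      // Q uses h attention heads while K/V use hg heads (hg == h unless MQA/GQA is used).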
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
descale_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Descale_q")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
      descale_k = mha_graph->tensor_like(descale_q, "Descale_K");
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
fe::graph::SDPA_fp8_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_fp8_attributes()
.set_name("sdpa_fp8")
.set_is_inference(false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
// sdpa_options.set_alibi_mask(is_alibi);
// if (is_bias) {
// bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({bias_b, bias_h, s_q, s_kv})
// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
// sdpa_options.set_bias(bias);
// }
// if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// sdpa_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
// if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// sdpa_options.set_dropout(
// dropout_probability, dropout_seed, dropout_offset);
// }
auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8(
Q, K, V, descale_q, descale_k, descale_v, descale_s,
scale_s, scale_o, sdpa_options);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1});
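      // O is written in FP8; Stats (the softmax statistics) and the amax values are
      // produced in FP32 for the backward pass and for scale updates.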
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_s
std::shared_ptr<fe::graph::Tensor_attributes> > // amax_o
key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v,
descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o);
auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
Stats_tuple, bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s,
scale_s, scale_o, attn_scale, O, amax_s, amax_o, Stats,
bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_fp8_fprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ},
{K, devPtrK},
{V, devPtrV},
{descale_q, devPtrDescaleQ},
{descale_k, devPtrDescaleK},
{descale_v, devPtrDescaleV},
{descale_s, devPtrDescaleS},
{scale_s, devPtrScaleS},
{scale_o, devPtrScaleO},
{attn_scale, &scaling_factor},
{O, devPtrO},
{amax_s, devPtrAmaxS},
{amax_o, devPtrAmaxO},
{Stats, devPtrM}};
// if (is_bias) {
// variant_pack[bias] = devPtrBias;
// }
// if (is_padding) {
// constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
// + b * sizeof(int32_t);
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV),
// static_cast<int32_t *>(devActualSeqlenQ),
// static_cast<int32_t *>(devActualSeqlenKV));
// variant_pack[seq_q] = devActualSeqlenQ;
// variant_pack[seq_kv] = devActualSeqlenKV;
// }
// if (is_dropout) {
// variant_pack[dropout_seed] = devPtrDropoutSeed;
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
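    // Bind the device pointers gathered above and launch the cached graph.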
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
// fused attention BWD FP8 with FE 1.0+
void fused_attn_fp8_bwd_impl_v1(int64_t b, int64_t h, int64_t hg,
int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
void* devPtrO, void* devPtrdO,
void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleO, void* devPtrDescaledO,
void* devPtrDescaleS, void* devPtrDescaledP,
void* devPtrScaleS, void* devPtrScaledP,
void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdP,
void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnn_frontend::DataType_t fwd_tensor_type,
cudnn_frontend::DataType_t bwd_tensor_type,
void* workspace,
size_t* workspace_size,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
auto bias_b = b;
auto bias_h = h;
  NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
  NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
  NVTE_CHECK(!is_padding,
             "FP8 fused attention does not support padding/padding_causal mask yet!");
  NVTE_CHECK(!is_dropout, "FP8 fused attention does not support dropout yet!");
try {
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
fwd_tensor_type, bwd_tensor_type};
namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dO
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dV
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dV
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // dropout_seed
std::shared_ptr<fe::graph::Tensor_attributes> >; // dropout_offset
using CacheType = std::map<FADescriptor_v1, graph_and_tensors>;
static thread_local CacheType sdpa_fp8_bprop_cache;
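    // Backward graphs are cached separately from forward graphs, under the same descriptor type.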
// Get plan from cache if cache is available, otherwise create one
auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor)
-> graph_and_tensors {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto graph = it->second;
return graph;
}
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(fwd_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> descale_q, descale_k, descale_v;
std::shared_ptr<fe::graph::Tensor_attributes> descale_s, descale_o;
std::shared_ptr<fe::graph::Tensor_attributes> descale_dP, descale_dO;
std::shared_ptr<fe::graph::Tensor_attributes> scale_s, scale_dP;
std::shared_ptr<fe::graph::Tensor_attributes> scale_dQ, scale_dK, scale_dV;
std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_stride));
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
descale_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Descale_q")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
      descale_k = mha_graph->tensor_like(descale_q, "Descale_K");
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP");
descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP");
scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options;
sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes()
.set_name("sdpa_fp8_backward")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
// sdpa_backward_options.set_alibi_mask(is_alibi);
// if (is_bias) {
// bias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("bias")
// .set_dim({bias_b, bias_h, s_q, s_kv})
// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
// dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("dBias")
// .set_dim({bias_b, bias_h, s_q, s_kv})
// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
// sdpa_backward_options.set_bias(bias);
// // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// // are not supported for dbias calculation but they are
// // supported for forward bias calculation
// if ((bias_b == 1) && (bias_h == h)) {
// sdpa_backward_options.set_dbias(dBias);
// }
// }
// if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// sdpa_backward_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
// if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// sdpa_backward_options.set_dropout(
// dropout_probability, dropout_seed, dropout_offset);
// }
auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward(
q, k, v, o, dO, stats,
descale_q, descale_k, descale_v,
descale_o, descale_dO, descale_s, descale_dP,
scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
sdpa_backward_options);
dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_stride);
amax_dQ->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dK->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dV->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
amax_dP->set_output(true)
.set_dim({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
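      // The gradient tensors may use a different FP8 type than the forward tensors,
      // so their data type is overridden to bwd_tensor_type below.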
dO->set_data_type(bwd_tensor_type);
dQ->set_data_type(bwd_tensor_type);
dK->set_data_type(bwd_tensor_type);
dV->set_data_type(bwd_tensor_type);
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
std::shared_ptr<fe::graph::Tensor_attributes>, // v
std::shared_ptr<fe::graph::Tensor_attributes>, // o
std::shared_ptr<fe::graph::Tensor_attributes>, // stats
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_q
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_k
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_v
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_o
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dO
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // descale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dV
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_s
std::shared_ptr<fe::graph::Tensor_attributes>, // scale_dP
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // dK
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dQ
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dK
std::shared_ptr<fe::graph::Tensor_attributes>, // amax_dV
std::shared_ptr<fe::graph::Tensor_attributes> > // amax_dP
key_tensors_tuple = std::make_tuple(
q, k, v, o, stats, dO, attn_scale,
descale_q, descale_k, descale_v,
descale_o, descale_dO, descale_s, descale_dP,
scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
dQ, dK, dV,
amax_dQ, amax_dK, amax_dV, amax_dP);
auto bias_tuple = is_bias ?
std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto dropout_tuple = is_dropout ?
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr);
NVTE_CHECK_CUDNN_FE(mha_graph->validate());
NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A}));
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple,
bias_tuple, padding_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, stats, dO, attn_scale,
descale_q, descale_k, descale_v,
descale_o, descale_dO, descale_s, descale_dP,
scale_s, scale_dQ, scale_dK, scale_dV, scale_dP,
dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP,
bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph(
sdpa_fp8_bprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{q, devPtrQ},
{k, devPtrK},
{v, devPtrV},
{o, devPtrO},
{stats, devPtrM},
{dO, devPtrdO},
{attn_scale, &scaling_factor},
{descale_q, devPtrDescaleQ},
{descale_k, devPtrDescaleK},
{descale_v, devPtrDescaleV},
{descale_o, devPtrDescaleO},
{descale_dO, devPtrDescaledO},
{descale_s, devPtrDescaleS},
{descale_dP, devPtrDescaledP},
{scale_s, devPtrScaleS},
{scale_dQ, devPtrScaledQ},
{scale_dK, devPtrScaledK},
{scale_dV, devPtrScaledV},
{scale_dP, devPtrScaledP},
{dQ, devPtrdQ},
{dK, devPtrdK},
{dV, devPtrdV},
{amax_dQ, devPtrAmaxdQ},
{amax_dK, devPtrAmaxdK},
{amax_dV, devPtrAmaxdV},
{amax_dP, devPtrAmaxdP},
};
// if (is_bias) {
// variant_pack[bias] = devPtrBias;
// if ((bias_b == 1) && (bias_h == h)) {
// variant_pack[dBias] = devPtrdBias;
// } else {
// variant_pack[dBias] = nullptr;
// }
// }
// if (is_padding) {
// constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
// + b * sizeof(int32_t);
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV),
// static_cast<int32_t *>(devActualSeqlenQ),
// static_cast<int32_t *>(devActualSeqlenKV));
// variant_pack[seq_q] = devActualSeqlenQ;
// variant_pack[seq_kv] = devActualSeqlenKV;
// }
// if (is_dropout) {
// variant_pack[dropout_seed] = devPtrDropoutSeed;
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
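    // Bind the device pointers gathered above and launch the cached backward graph.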
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
 #endif
 }  // namespace fused_attn
@@ -1853,9 +2555,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
 #if (CUDNN_VERSION >= 8900)
 // fused attention FWD FP8 with packed QKV
 void fused_attn_fp8_fwd_qkvpacked(
-                size_t b, size_t h, size_t max_seqlen, size_t d,
+                size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
                 bool is_training, float attn_scale,
                 float p_dropout, NVTE_QKV_Layout qkv_layout,
+                NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                 const Tensor *input_QKV,
                 Tensor *input_output_S,
                 Tensor *output_O,
@@ -1866,11 +2569,18 @@ void fused_attn_fp8_fwd_qkvpacked(
                 cudaStream_t stream,
                 cudnnHandle_t handle) {
   using namespace transformer_engine;
-  // QKV shape is [total_seqs, 3, h, d]
+  const DType QKV_type = input_QKV->data.dtype;
   void* devPtrQKV = input_QKV->data.dptr;
-  void* devPtrQ = reinterpret_cast<void *>(devPtrQKV);
-  void* devPtrK = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + h * d);
-  void* devPtrV = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + 2 * h * d);
+  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  size_t stride = 0;
+  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
+    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
+    stride = typeToSize(QKV_type) * head_dim;
+  }
+  void *devPtrQ = static_cast<void *>(devPtrQKV);
+  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
+  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
   void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
   void* devPtrDescaleK = input_QKV->scale_inv.dptr;
   void* devPtrDescaleV = input_QKV->scale_inv.dptr;
@@ -1882,21 +2592,19 @@ void fused_attn_fp8_fwd_qkvpacked(
   void* devPtrM = nullptr;
   void* devPtrZInv = nullptr;
   if (Aux_CTX_Tensors->size == 0) {
-    if (is_training) {
     Aux_CTX_Tensors->size = 3;
     Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
     Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
     Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
     output_M->data.dptr = nullptr;
-    output_M->data.shape = {b, h, max_seqlen, 1};
+    output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
     output_M->data.dtype = DType::kFloat32;
     output_ZInv->data.dptr = nullptr;
-    output_ZInv->data.shape = {b, h, max_seqlen, 1};
+    output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1};
     output_ZInv->data.dtype = DType::kFloat32;
     output_rng_state->data.dptr = nullptr;
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
-    }
   } else if (Aux_CTX_Tensors->size == 3) {
     Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
     Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
@@ -1919,11 +2627,27 @@ void fused_attn_fp8_fwd_qkvpacked(
   void* devPtrDropoutOffset = reinterpret_cast<void *>(
                   reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
-  const DType QKV_type = input_QKV->data.dtype;
   size_t workspace_size = 0;
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
+      || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
+    fused_attn::fused_attn_fp8_fwd_impl_v1(
+                    batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim,
+                    is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+                    devPtrQ, devPtrK, devPtrV,
+                    devPtrM, devPtrZInv,
+                    devPtrO,
+                    devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
+                    devPtrDescaleS, devPtrScaleS, devPtrScaleO,
+                    devPtrAmaxO, devPtrAmaxS,
+                    devPtrcuSeqlens, devPtrcuSeqlens,
+                    devPtrDropoutSeed, devPtrDropoutOffset,
+                    get_cudnn_fe_dtype(QKV_type),
+                    workspace->data.dptr, &workspace_size, stream, handle);
+  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
-                    b, h, max_seqlen, max_seqlen, d,
+                    batch, num_attn_heads, max_seqlen, max_seqlen, head_dim,
                     is_training, attn_scale, p_dropout, qkv_layout,
                     devPtrQ, devPtrK, devPtrV,
                     devPtrM, devPtrZInv,
@@ -1935,6 +2659,9 @@ void fused_attn_fp8_fwd_qkvpacked(
                     devPtrDropoutSeed, devPtrDropoutOffset,
                     get_cudnn_dtype(QKV_type),
                     workspace->data.dptr, &workspace_size, stream, handle);
+  } else {
+    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
+  }
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -1950,8 +2677,9 @@ void fused_attn_fp8_fwd_qkvpacked(
 }
 // fused attention BWD FP8 with packed QKV
 void fused_attn_fp8_bwd_qkvpacked(
-                size_t b, size_t h, size_t max_seqlen, size_t d,
+                size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
                 float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+                NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                 const Tensor *input_QKV,
                 const Tensor *input_O,
                 const Tensor *input_dO,
@@ -1966,11 +2694,19 @@ void fused_attn_fp8_bwd_qkvpacked(
                 cudaStream_t stream,
                 cudnnHandle_t handle) {
   using namespace transformer_engine;
-  // QKV shape is [total_seqs, 3, h, d]
+  const DType QKV_type = input_QKV->data.dtype;
+  const DType dQKV_type = output_dQKV->data.dtype;
   void* devPtrQKV = input_QKV->data.dptr;
-  void* devPtrQ = reinterpret_cast<void *>(devPtrQKV);
-  void* devPtrK = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + h * d);
-  void* devPtrV = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + 2 * h * d);
+  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  size_t stride = 0;
+  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
+    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
+    stride = typeToSize(QKV_type) * head_dim;
+  }
+  void *devPtrQ = devPtrQKV;
+  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
+  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
   void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
   void* devPtrDescaleK = input_QKV->scale_inv.dptr;
   void* devPtrDescaleV = input_QKV->scale_inv.dptr;
@@ -1985,15 +2721,14 @@ void fused_attn_fp8_bwd_qkvpacked(
   void* devPtrScaleS = input_S->scale.dptr;
   void* devPtrDescaleS = input_S->scale_inv.dptr;
-  void* devPtrAmaxdS = input_output_dP->amax.dptr;
-  void* devPtrScaledS = input_output_dP->scale.dptr;
-  void* devPtrDescaledS = input_output_dP->scale_inv.dptr;
-  // dQKV shape is [total_seqs, 3, h, d]
-  void* devPtrdQKV = output_dQKV->data.dptr;
-  void* devPtrdQ = reinterpret_cast<void *>(devPtrdQKV);
-  void* devPtrdK = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrdQKV) + h * d);
-  void* devPtrdV = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrdQKV) + 2 * h * d);
+  void* devPtrAmaxdP = input_output_dP->amax.dptr;
+  void* devPtrScaledP = input_output_dP->scale.dptr;
+  void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
+  void *devPtrdQKV = output_dQKV->data.dptr;
+  void *devPtrdQ = devPtrdQKV;
+  void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
+  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
   void* devPtrAmaxdQ = output_dQKV->amax.dptr;
   void* devPtrAmaxdK = output_dQKV->amax.dptr;
   void* devPtrAmaxdV = output_dQKV->amax.dptr;
@@ -2008,11 +2743,33 @@ void fused_attn_fp8_bwd_qkvpacked(
   void* devPtrDropoutOffset = reinterpret_cast<void *>(
                   reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
-  const DType QKV_type = input_QKV->data.dtype;
   size_t workspace_size = 0;
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
+      || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
+    fused_attn::fused_attn_fp8_bwd_impl_v1(
+                    batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim,
+                    attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+                    devPtrQ, devPtrK, devPtrV,
+                    devPtrM, devPtrZInv,
+                    devPtrO, devPtrdO,
+                    devPtrdQ, devPtrdK, devPtrdV,
+                    devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
+                    devPtrDescaleO, devPtrDescaledO,
+                    devPtrDescaleS, devPtrDescaledP,
+                    devPtrScaleS, devPtrScaledP,
+                    devPtrScaledQ, devPtrScaledK, devPtrScaledV,
+                    devPtrAmaxdP,
+                    devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
+                    devPtrcuSeqlens, devPtrcuSeqlens,
+                    devPtrDropoutSeed, devPtrDropoutOffset,
+                    get_cudnn_fe_dtype(QKV_type),
+                    get_cudnn_fe_dtype(dQKV_type),
+                    workspace->data.dptr, &workspace_size, stream, handle);
+  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
-                    b, h, max_seqlen, max_seqlen, d,
+                    batch, num_attn_heads, max_seqlen, max_seqlen, head_dim,
                     attn_scale, p_dropout, qkv_layout,
                     devPtrQ, devPtrK, devPtrV,
                     devPtrM, devPtrZInv,
@@ -2020,15 +2777,278 @@ void fused_attn_fp8_bwd_qkvpacked(
                     devPtrdQ, devPtrdK, devPtrdV,
                     devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
                     devPtrDescaleO, devPtrDescaledO,
-                    devPtrDescaleS, devPtrDescaledS,
-                    devPtrScaleS, devPtrScaledS,
+                    devPtrDescaleS, devPtrDescaledP,
+                    devPtrScaleS, devPtrScaledP,
                     devPtrScaledQ, devPtrScaledK, devPtrScaledV,
-                    devPtrAmaxdS,
+                    devPtrAmaxdP,
                     devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
                     devPtrcuSeqlens, devPtrcuSeqlens,
                     devPtrDropoutSeed, devPtrDropoutOffset,
                     get_cudnn_dtype(QKV_type),
                     workspace->data.dptr, &workspace_size, stream, handle);
+  } else {
+    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
+  }
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
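  // Offset of V within the packed KV tensor: heads-first packing (HD_2HD) steps over
  // all num_gqa_groups heads, while per-head packing (HD_H2D) steps over a single head_dim.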
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_KV->scale_inv.dptr;
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
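  // rng_state packs the dropout seed and offset as two consecutive 64-bit values.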
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dKV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void* devPtrDescaleQ = input_Q->scale_inv.dptr;
void* devPtrDescaleK = input_KV->scale_inv.dptr;
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdP = input_output_dP->amax.dptr;
void* devPtrScaledP = input_output_dP->scale.dptr;
void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
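  // dP keeps its own scale/amax/descale because the FP8 softmax gradient is
  // quantized separately from S and from the dQ/dK/dV outputs.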
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdKV = output_dKV->data.dptr;
void *devPtrdK = devPtrdKV;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdKV) + stride);
void* devPtrAmaxdQ = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dKV->amax.dptr;
void* devPtrAmaxdV = output_dKV->amax.dptr;
void* devPtrScaledQ = output_dQ->scale.dptr;
void* devPtrScaledK = output_dKV->scale.dptr;
void* devPtrScaledV = output_dKV->scale.dptr;
void* devPtrcuSeqlensQ = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
void* devPtrcuSeqlensKV = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -2044,9 +3064,11 @@ void fused_attn_fp8_bwd_qkvpacked(
 }
 // fused attention FWD FP8 with separate Q, K, V
 void fused_attn_fp8_fwd(
-                size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
+                size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+                size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
                 bool is_training, float attn_scale,
                 float p_dropout, NVTE_QKV_Layout qkv_layout,
+                NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                 const Tensor *input_Q,
                 const Tensor *input_K,
                 const Tensor *input_V,
@@ -2074,21 +3096,19 @@ void fused_attn_fp8_fwd(
   void* devPtrM = nullptr;
   void* devPtrZInv = nullptr;
   if (Aux_CTX_Tensors->size == 0) {
-    if (is_training) {
     Aux_CTX_Tensors->size = 3;
     Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
     Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
     Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
     output_M->data.dptr = nullptr;
-    output_M->data.shape = {b, h, max_seqlen_q, 1};
+    output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     output_M->data.dtype = DType::kFloat32;
     output_ZInv->data.dptr = nullptr;
-    output_ZInv->data.shape = {b, h, max_seqlen_q, 1};
+    output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     output_ZInv->data.dtype = DType::kFloat32;
     output_rng_state->data.dptr = nullptr;
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
-    }
   } else if (Aux_CTX_Tensors->size == 3) {
     Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
     Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
@@ -2116,8 +3136,25 @@ void fused_attn_fp8_fwd(
   const DType QKV_type = input_Q->data.dtype;
   size_t workspace_size = 0;
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
+      || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
+    fused_attn::fused_attn_fp8_fwd_impl_v1(
+                    batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
+                    is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+                    devPtrQ, devPtrK, devPtrV,
+                    devPtrM, devPtrZInv,
+                    devPtrO,
+                    devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
+                    devPtrDescaleS, devPtrScaleS, devPtrScaleO,
+                    devPtrAmaxO, devPtrAmaxS,
+                    devPtrcuSeqlensQ, devPtrcuSeqlensKV,
+                    devPtrDropoutSeed, devPtrDropoutOffset,
+                    get_cudnn_fe_dtype(QKV_type),
+                    workspace->data.dptr, &workspace_size, stream, handle);
+  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
-                    b, h, max_seqlen_q, max_seqlen_kv, d,
+                    batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
                     is_training, attn_scale, p_dropout, qkv_layout,
                     devPtrQ, devPtrK, devPtrV,
                     devPtrM, devPtrZInv,
@@ -2129,6 +3166,9 @@ void fused_attn_fp8_fwd(
                     devPtrDropoutSeed, devPtrDropoutOffset,
                     get_cudnn_dtype(QKV_type),
                     workspace->data.dptr, &workspace_size, stream, handle);
+  } else {
+    NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
+  }
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
@@ -2144,8 +3184,10 @@ void fused_attn_fp8_fwd(
 }
 // fused attention BWD FP8 with separate Q, K, V
 void fused_attn_fp8_bwd(
-                size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
+                size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
+                size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
                 float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+                NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                 const Tensor *input_Q,
                 const Tensor *input_K,
                 const Tensor *input_V,
@@ -2182,9 +3224,9 @@ void fused_attn_fp8_bwd(
   void* devPtrScaleS = input_S->scale.dptr;
   void* devPtrDescaleS = input_S->scale_inv.dptr;
-  void* devPtrAmaxdS = input_output_dP->amax.dptr;
-  void* devPtrScaledS = input_output_dP->scale.dptr;
-  void* devPtrDescaledS = input_output_dP->scale_inv.dptr;
+  void* devPtrAmaxdP = input_output_dP->amax.dptr;
+  void* devPtrScaledP = input_output_dP->scale.dptr;
+  void* devPtrDescaledP = input_output_dP->scale_inv.dptr;
   void* devPtrdQ = output_dQ->data.dptr;
   void* devPtrdK = output_dK->data.dptr;
...@@ -2206,10 +3248,34 @@ void fused_attn_fp8_bwd( ...@@ -2206,10 +3248,34 @@ void fused_attn_fp8_bwd(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype; const DType QKV_type = input_Q->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_fe_dtype(QKV_type),
get_cudnn_fe_dtype(dQKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl( fused_attn::fused_attn_fp8_bwd_impl(
b, h, max_seqlen_q, max_seqlen_kv, d, batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv, devPtrM, devPtrZInv,
...@@ -2217,15 +3283,18 @@ void fused_attn_fp8_bwd( ...@@ -2217,15 +3283,18 @@ void fused_attn_fp8_bwd(
devPtrdQ, devPtrdK, devPtrdV, devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO, devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledS, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledS, devPtrScaleS, devPtrScaledP,
devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdS, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle); workspace->data.dptr, &workspace_size, stream, handle);
} else {
NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n");
}
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
......
...@@ -14,9 +14,10 @@ namespace transformer_engine { ...@@ -14,9 +14,10 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV // fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked( void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
bool is_training, float attn_scale, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_QKV,
Tensor *input_output_S, Tensor *input_output_S,
Tensor *output_O, Tensor *output_O,
...@@ -29,8 +30,9 @@ void fused_attn_fp8_fwd_qkvpacked( ...@@ -29,8 +30,9 @@ void fused_attn_fp8_fwd_qkvpacked(
// fused attention BWD FP8 with packed QKV // fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked( void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_QKV,
const Tensor *input_O, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_dO,
...@@ -45,11 +47,55 @@ void fused_attn_fp8_bwd_qkvpacked( ...@@ -45,11 +47,55 @@ void fused_attn_fp8_bwd_qkvpacked(
cudaStream_t stream, cudaStream_t stream,
cudnnHandle_t handle); cudnnHandle_t handle);
// fused attention FWD FP8 with packed KV
void fused_attn_fp8_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed KV
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q,
const Tensor *input_KV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQ,
const Tensor *output_dKV,
const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention FWD FP8 with separate Q, K, V // fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd( void fused_attn_fp8_fwd(
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
Tensor *input_output_S, Tensor *input_output_S,
Tensor *output_O, Tensor *output_O,
...@@ -63,8 +109,10 @@ void fused_attn_fp8_fwd( ...@@ -63,8 +109,10 @@ void fused_attn_fp8_fwd(
// fused attention BWD FP8 with separate Q, K, V // fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd( void fused_attn_fp8_bwd(
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_dO,
......
...@@ -111,19 +111,20 @@ struct FADescriptor_v1 { ...@@ -111,19 +111,20 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout; NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type; NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type; NVTE_Mask_Type mask_type;
cudnn_frontend::DataType_t tensor_type; cudnn_frontend::DataType_t fwd_tensor_type;
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const { bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type) layout, mask_type, bias_type, fwd_tensor_type, bwd_tensor_type)
< std::tie( < std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.bias_b, rhs.bias_h, rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type, rhs.mask_type, rhs.bias_type,
rhs.tensor_type); rhs.fwd_tensor_type, rhs.bwd_tensor_type);
} }
}; };
......
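Keying cached plans on the extended descriptor means an FP8 forward paired with an F16 backward (NVTE_FP8_DPA_BWD=0) cannot collide with a full-FP8 run. A toy sketch of that keying idea, not the TE implementation:

```python
# Toy cache keyed the way FADescriptor_v1 compares: two descriptors that differ only
# in bwd_tensor_type must map to different cached graphs.
from typing import NamedTuple

class FAKey(NamedTuple):
    b: int
    h: int
    s_q: int
    s_kv: int
    d: int
    fwd_tensor_type: str
    bwd_tensor_type: str

cache = {}
cache[FAKey(2, 16, 512, 512, 64, "FP8_E4M3", "FP8_E5M2")] = "fp8_bwd_graph"
cache[FAKey(2, 16, 512, 512, 64, "FP8_E4M3", "BFLOAT16")] = "f16_bwd_graph"
assert len(cache) == 2   # dropping bwd_tensor_type would make these keys collide
```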
...@@ -96,7 +96,7 @@ class DelayedScaling: ...@@ -96,7 +96,7 @@ class DelayedScaling:
where `Tensor` is a framework tensor type. where `Tensor` is a framework tensor type.
override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False) override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
Whether or not the execute the `fprop`, `dgrad`, and `wgrad` Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
GEMMs (respectively) in higher precision when using FP8. GEMMs (respectively) in higher precision when using FP8.
reduce_amax: bool, default = `True` reduce_amax: bool, default = `True`
By default, if `torch.distributed` is initialized, the `amax` value for FP8 By default, if `torch.distributed` is initialized, the `amax` value for FP8
...@@ -106,6 +106,20 @@ class DelayedScaling: ...@@ -106,6 +106,20 @@ class DelayedScaling:
GPU maintains local amaxes and scaling factors. To ensure results are GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors. ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts the
outputs back to higher precision. FP8 DPA is currently only supported by the
`FusedAttention` backend (see the usage sketch at the end of this file's changes).
fp8_mha: bool, default = `False`
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently, only standard MHA modules,
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
`fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
`LayerNormLinear (BF16 output) -> (cast to FP8) FP8 DPA (cast to BF16) -> Linear`.
When `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
Notes Notes
----- -----
...@@ -116,6 +130,9 @@ class DelayedScaling: ...@@ -116,6 +130,9 @@ class DelayedScaling:
FP8_MAX = maximum_representable_value(fp8_format) FP8_MAX = maximum_representable_value(fp8_format)
new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin) new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)
* `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
subject to change in future Transformer Engine releases.
""" """
margin: int = 0 margin: int = 0
...@@ -126,6 +143,8 @@ class DelayedScaling: ...@@ -126,6 +143,8 @@ class DelayedScaling:
override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision() override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
scaling_factor_compute_algo: Optional[Callable] = None scaling_factor_compute_algo: Optional[Callable] = None
reduce_amax: bool = True reduce_amax: bool = True
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
......
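A hedged usage sketch of the two new recipe flags; the module choice, sizes, and dtypes are illustrative only, and `TransformerLayer`/`fp8_autocast` are the existing TE PyTorch entry points:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# fp8_dpa=True: attention itself runs in FP8; fp8_mha stays False, so casts to/from
# higher precision happen at the DPA boundaries as described in the docstring above.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=False)

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16,
                            params_dtype=torch.bfloat16).cuda()
x = torch.randn(512, 2, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()
```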
...@@ -19,6 +19,10 @@ import torch ...@@ -19,6 +19,10 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import ( from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked, fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked, fused_attn_bwd_qkvpacked,
...@@ -31,7 +35,10 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -31,7 +35,10 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnMaskType, AttnMaskType,
FusedAttnBackend, FusedAttnBackend,
) )
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
divide, divide,
attention_mask_func, attention_mask_func,
...@@ -74,6 +81,12 @@ if _flash_attn_version >= _flash_attn_version_required: ...@@ -74,6 +81,12 @@ if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
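These constants are slot indices into the module's FP8 scaling state. A toy illustration of the indexing scheme (not the TE API; slot count and values are made up):

```python
import torch

NUM_FWD_SLOTS = 3        # hypothetical: one slot each for QKV, O, S
META_S_SLOT = 2          # hypothetical stand-in for META_S

scale = torch.ones(NUM_FWD_SLOTS)
amax_history = torch.zeros(16, NUM_FWD_SLOTS)

amax_history[0][META_S_SLOT] = 3.5                           # amax observed for S
scale[META_S_SLOT] = 448.0 / amax_history[0][META_S_SLOT]    # E4M3 max / amax
scale_inv = 1.0 / scale                                      # descale read by kernels
```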
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_alibi_cache = { _alibi_cache = {
...@@ -810,7 +823,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -810,7 +823,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd( dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k, cu_seqlens_q, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]], [softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
...@@ -850,7 +863,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -850,7 +863,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd( dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2, ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2, cu_seqlens_q, cu_seqlens_k//2,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]], [softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
...@@ -890,7 +903,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -890,7 +903,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd( dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k, ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k, cu_seqlens_q//2, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse_, ctx.rng_states[cp_size-i-1]], [softmax_lse_, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
...@@ -923,7 +936,7 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -923,7 +936,7 @@ class AttnFuncWithCP(torch.autograd.Function):
dq_, dk_, dv_, _ = fused_attn_bwd( dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k, cu_seqlens_q, cu_seqlens_k,
q, kv[0], kv[1], out, dout, TE_DType[q.dtype], q, kv[0], kv[1], out, dout, TE_DType[q.dtype], TE_DType[kv.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]], [softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
...@@ -1246,6 +1259,14 @@ class _SplitAlongDim(torch.autograd.Function): ...@@ -1246,6 +1259,14 @@ class _SplitAlongDim(torch.autograd.Function):
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
ctx.split_dim = split_dim ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections ctx.split_size_or_sections = split_size_or_sections
if isinstance(mixed_x_layer, Float8Tensor):
return tuple(Float8Tensor.make_like(
mixed_x_layer,
data=x,
) for x in torch.split(
mixed_x_layer._data,
split_size_or_sections=split_size_or_sections,
dim=split_dim))
return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim) return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim)
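The Float8Tensor branch above splits the raw `_data` buffer and re-wraps each chunk, so the pieces keep aliasing the original storage. A small sketch of that zero-copy property:

```python
import torch

raw = torch.randint(0, 255, (4, 2, 3 * 64), dtype=torch.uint8)   # stand-in for ._data
chunks = torch.split(raw, 64, dim=-1)

# torch.split returns views, so every chunk shares the original storage; this is what
# lets backward() recognize the "no-op concatenation" case instead of copying.
assert all(c.untyped_storage().data_ptr() == raw.untyped_storage().data_ptr()
           for c in chunks)
```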
@staticmethod @staticmethod
...@@ -1262,6 +1283,37 @@ class _SplitAlongDim(torch.autograd.Function): ...@@ -1262,6 +1283,37 @@ class _SplitAlongDim(torch.autograd.Function):
dims = len(grad_outputs[0].shape) dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims split_dim = (ctx.split_dim + dims) % dims
if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:])
if (tensor.stride() != strides or
list(tensor.shape) != shape_i or
tensor._data.untyped_storage().data_ptr() != data_ptr or
tensor.storage_offset() != offset_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0]._data.dtype)
new_shape = list(shape)
new_shape[split_dim] = sum(split_sizes)
ret.set_(grad_outputs[0]._data.untyped_storage(),
grad_outputs[0]._data.storage_offset(),
new_shape,
strides
)
return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None
grad_outputs_data = [x._data for x in grad_outputs]
return Float8Tensor.make_like(
grad_outputs[0],
data=torch.cat(grad_outputs_data, dim = split_dim)), None, None
noop_ok = True noop_ok = True
strides = grad_outputs[0].stride() strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].untyped_storage().data_ptr() data_ptr = grad_outputs[0].untyped_storage().data_ptr()
...@@ -1276,7 +1328,6 @@ class _SplitAlongDim(torch.autograd.Function): ...@@ -1276,7 +1328,6 @@ class _SplitAlongDim(torch.autograd.Function):
tensor.storage_offset() != offset_size): tensor.storage_offset() != offset_size):
noop_ok = False noop_ok = False
break break
if noop_ok: if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device, ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype) dtype=grad_outputs[0].dtype)
...@@ -1848,6 +1899,35 @@ class FlashAttention(torch.nn.Module): ...@@ -1848,6 +1899,35 @@ class FlashAttention(torch.nn.Module):
return output return output
def _combine_tensors(
tensors: List[torch.Tensor],
dim: int,
) -> torch.Tensor:
"""Combine tensors along a particular dimension"""
num_tensors = len(tensors)
new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors)
new_stride = list(tensors[0].stride())
new_stride.insert(dim, int(new_stride[dim-1]/num_tensors))
if isinstance(tensors[0], Float8Tensor):
combined_tensor = torch.Tensor().to(
device=tensors[0].device, dtype=tensors[0]._data.dtype)
combined_tensor.set_(
tensors[0]._data.untyped_storage(),
tensors[0]._data.storage_offset(),
new_shape, new_stride)
combined_tensor = Float8Tensor.make_like(
tensors[0], data=combined_tensor)
else:
combined_tensor = torch.Tensor().to(
device=tensors[0].device, dtype=tensors[0].dtype)
combined_tensor.set_(
tensors[0].untyped_storage(),
tensors[0].storage_offset(),
new_shape, new_stride)
return combined_tensor
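`_combine_tensors` rebuilds a packed tensor purely from shape/stride arithmetic. A minimal check of that trick, assuming the inputs are views into one contiguous buffer (e.g. q/k/v sliced out of an sb3hd projection):

```python
import torch

qkv = torch.randn(4, 2, 3, 8, 64)                        # s, b, 3, h, d (contiguous)
q, k, v = [t.squeeze(2) for t in qkv.split(1, dim=2)]    # views into the same storage

dim = 2
new_shape = list(q.shape)
new_shape.insert(dim, 3)
new_stride = list(q.stride())
new_stride.insert(dim, new_stride[dim - 1] // 3)         # same rule as _combine_tensors

combined = torch.Tensor().to(dtype=q.dtype)
combined.set_(q.untyped_storage(), q.storage_offset(), new_shape, new_stride)
assert torch.equal(combined, qkv)                        # original tensor reconstructed
```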
class FusedAttnFunc_qkvpacked(torch.autograd.Function): class FusedAttnFunc_qkvpacked(torch.autograd.Function):
"""Function for FusedAttention with packed QKV input""" """Function for FusedAttention with packed QKV input"""
...@@ -1855,15 +1935,83 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -1855,15 +1935,83 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen, fused_attention_backend, use_FAv2_bwd): rng_gen, fused_attention_backend, use_FAv2_bwd,
out, aux_ctx_tensors = fused_attn_fwd_qkvpacked( fp8, fp8_meta, tp_size, tp_group):
if fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 forward')
if fp8_meta["recipe"].fp8_mha:
assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
assert (qkv_group == 1
), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \
but found {qkv_layout}."
if fp8_meta["recipe"].fp8_mha:
qkv_fp8 = qkv._data
else:
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_fp8 = cast_to_fp8(qkv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(qkv.shape)
out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens,
qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale, dropout_p, fast_zero_fill, qkv_layout,
attn_bias_type, attn_mask_type, rng_gen)
if fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=qkv.dtype,
)
else:
out_ret = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv = cast_from_fp8(qkv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
fp8_tensors = (qkv_fp8, out_fp8,
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone())
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 forward')
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype,
fused_attention_backend, attn_bias, fused_attention_backend, attn_bias,
None, None, None, None, None, None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
fp8_tensors = (None, None, None, None)
ctx.save_for_backward(qkv, out, cu_seqlens) out_save = out_ret
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
ctx.fp8_meta = fp8_meta
ctx.tp_size = tp_size
ctx.tp_group = tp_group
ctx.aux_ctx_tensors = aux_ctx_tensors ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen = max_seqlen ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype ctx.qkv_dtype = qkv_dtype
...@@ -1873,15 +2021,23 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -1873,15 +2021,23 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend ctx.fused_attention_backend = \
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
ctx.use_FAv2_bwd = use_FAv2_bwd ctx.use_FAv2_bwd = use_FAv2_bwd
return out return out_ret
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha:
assert (isinstance(d_out, Float8Tensor)
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
d_out_f8tensor = d_out
d_out = d_out._data
d_out = d_out.contiguous() d_out = d_out.contiguous()
qkv, out, cu_seqlens = ctx.saved_tensors (qkv, out, cu_seqlens,
qkv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous(): if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd: if ctx.use_FAv2_bwd:
...@@ -1898,11 +2054,63 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -1898,11 +2054,63 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
) )
dqkv = dqkv[..., :d_out.shape[-1]] dqkv = dqkv[..., :d_out.shape[-1]]
else: else:
with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
if ctx.fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False)
if ctx.fp8_meta["recipe"].fp8_mha:
d_out_fp8 = d_out
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else:
d_out_fp8 = cast_to_fp8(
d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
).view(d_out.shape)
dqkv_fp8, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens,
qkv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
if ctx.fp8_meta["recipe"].fp8_mha:
dqkv = Float8Tensor(data=dqkv_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
else:
dqkv_c_fp8 = dqkv_fp8.view(-1,
dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
dqkv = cast_from_fp8(dqkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 backward')
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(qkv.dtype)
dqkv, *rest = fused_attn_bwd_qkvpacked( dqkv, *rest = fused_attn_bwd_qkvpacked(
ctx.max_seqlen, cu_seqlens, qkv, out, d_out, ctx.max_seqlen, cu_seqlens, qkv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors, ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
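Forward and backward precision are decoupled by the NVTE_FP8_DPA_BWD environment variable checked throughout these autograd functions. A small sketch of the gating, assuming the same default of "1" as in this change:

```python
import os

def backward_runs_in_fp8(fp8_forward_enabled: bool) -> bool:
    # mirrors `ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))`
    return bool(fp8_forward_enabled and int(os.getenv("NVTE_FP8_DPA_BWD", "1")))

os.environ["NVTE_FP8_DPA_BWD"] = "0"
assert backward_runs_in_fp8(True) is False   # FP8 forward, F16 backward fallback
```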
...@@ -1923,16 +2131,90 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -1923,16 +2131,90 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
rng_gen, fused_attention_backend, use_FAv2_bwd): use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
out, aux_ctx_tensors = fused_attn_fwd_kvpacked( if fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 forward')
if fp8_meta["recipe"].fp8_mha:
assert (isinstance(q, Float8Tensor)
and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
q_fp8, kv_fp8 = q._data, kv._data
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
assert (qkv_group == 2
), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \
but found {qkv_layout}."
q_fp8 = cast_to_fp8(q,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(q.shape)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = cast_to_fp8(kv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(kv.shape)
out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale, dropout_p, fast_zero_fill, qkv_layout,
attn_bias_type, attn_mask_type, rng_gen)
if fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q.dtype,
)
else:
out_ret = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q = cast_from_fp8(q._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv = cast_from_fp8(kv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
fp8_tensors = (q_fp8, kv_fp8, out_fp8,
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone())
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 forward')
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, qkv_dtype, fused_attention_backend, attn_bias, q, kv, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None, None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
out_save = out_ret
ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv) fp8_tensors = (None, None, None, None, None)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
ctx.fp8_meta = fp8_meta
ctx.tp_size = tp_size
ctx.tp_group = tp_group
ctx.aux_ctx_tensors = aux_ctx_tensors ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv ctx.max_seqlen_kv = max_seqlen_kv
...@@ -1943,15 +2225,23 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -1943,15 +2225,23 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend ctx.fused_attention_backend = \
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
ctx.use_FAv2_bwd = use_FAv2_bwd ctx.use_FAv2_bwd = use_FAv2_bwd
return out return out_ret
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha:
assert (isinstance(d_out, Float8Tensor)
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
d_out_f8tensor = d_out
d_out = d_out._data
d_out = d_out.contiguous() d_out = d_out.contiguous()
q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors (q, kv, out, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous(): if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd: if ctx.use_FAv2_bwd:
...@@ -1970,12 +2260,75 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -1970,12 +2260,75 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
dq = dq[..., :d_out.shape[-1]] dq = dq[..., :d_out.shape[-1]]
dkv = dkv[..., :d_out.shape[-1]] dkv = dkv[..., :d_out.shape[-1]]
else: else:
with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
if ctx.fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False)
if ctx.fp8_meta["recipe"].fp8_mha:
d_out_fp8 = d_out
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else:
d_out_fp8 = cast_to_fp8(
d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
).view(d_out.shape)
dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, kv_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor(data=dq_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
dkv = Float8Tensor(data=dkv_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
else:
dq = cast_from_fp8(
dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
dkv_c_fp8 = dkv_fp8.view(-1,
dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
dkv = cast_from_fp8(dkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 backward')
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dkv, *rest = fused_attn_bwd_kvpacked( dq, dkv, *rest = fused_attn_bwd_kvpacked(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, kv, out, d_out, q, kv, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors, ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
...@@ -1989,32 +2342,153 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -1989,32 +2342,153 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None, None, None, None, None, None, None, None, None, None, None, None,
None, None, None, None, None, None) None, None, None, None, None, None)
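All three autograd functions classify the layout the same way. A sketch grounded in the `len(qkv_layout.split('_'))` checks used above:

```python
def qkv_grouping(qkv_layout: str) -> str:
    # 1: qkv packed, 2: kv packed, 3: qkv separate
    return {1: "qkv packed", 2: "kv packed",
            3: "qkv separate"}[len(qkv_layout.split("_"))]

assert qkv_grouping("sb3hd") == "qkv packed"
assert qkv_grouping("sbhd_sb2hd") == "kv packed"
assert qkv_grouping("bshd_bshd_bshd") == "qkv separate"
```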
class FusedAttnFunc(torch.autograd.Function): class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors""" """Function for FusedAttention with separate Q, K, V tensors"""
@staticmethod @staticmethod
def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
qkv_layout, attn_bias_type, attn_mask_type, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
rng_gen, fused_attention_backend, use_FAv2_bwd): use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
out, aux_ctx_tensors = fused_attn_fwd( if fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 forward')
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
assert (isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
if qkv_group == 1:
dim = qkv_layout.find('3')
qkv = _combine_tensors([q,k,v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_fp8 = cast_to_fp8(qkv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(qkv.shape)
q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1])
q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]]
if qkv_group == 2:
q_fp8 = cast_to_fp8(q,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(q.shape)
dim = qkv_layout.split('_')[1].find('2')
kv = _combine_tensors([k,v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = cast_to_fp8(kv_c,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(kv.shape)
k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1])
k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]]
if qkv_group == 3:
q_fp8 = cast_to_fp8(q,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(q.shape)
k_fp8 = cast_to_fp8(k,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(k.shape)
v_fp8 = cast_to_fp8(v,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward).view(v.shape)
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
attn_scale, dropout_p, fast_zero_fill, qkv_layout,
attn_bias_type, attn_mask_type, rng_gen)
if fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q.dtype,
)
else:
out_ret = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split('_'))
if qkv_group == 1:
dim = qkv_layout.find('3')
qkv = _combine_tensors([q,k,v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_no_fp8 = cast_from_fp8(qkv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape)
q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1])
q, k, v = [x.squeeze(dim) for x in [q, k, v]]
if qkv_group == 2:
q = cast_from_fp8(q._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
dim = qkv_layout.split('_')[1].find('2')
kv = _combine_tensors([k,v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = cast_from_fp8(kv_c._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape)
k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1])
k, v = [x.squeeze(dim) for x in [k, v]]
if qkv_group == 3:
q = cast_from_fp8(q._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape)
k = cast_from_fp8(k._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape)
v = cast_from_fp8(v._data,
fp8_meta["scaling_fwd"],
META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape)
out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], META_O,
fp8_dtype_forward, qkv_dtype).view(out_fp8.shape)
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8,
fp8_meta["scaling_fwd"].scale.clone(),
fp8_meta["scaling_fwd"].scale_inv.clone())
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 forward')
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, qkv_dtype, fused_attention_backend, attn_bias, q, k, v, qkv_dtype, fused_attention_backend, attn_bias,
None, None, None, None, None, None, None, None, None, None, None,
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
out_save = out_ret
fp8_tensors = (None, None, None, None, None, None)
from .cpu_offload import CPUOffloadEnabled from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled: if CPUOffloadEnabled:
tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv] tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv]
qkv_layout = 'sbhd_sbhd_sbhd' qkv_layout = 'sbhd_sbhd_sbhd'
for tensor in tensor_list: for tensor in tensor_list:
if tensor is not None: if tensor is not None:
tensor.activation_offloading = True tensor.activation_offloading = True
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv) qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
ctx.fp8_meta = fp8_meta
ctx.tp_size = tp_size
ctx.tp_group = tp_group
ctx.aux_ctx_tensors = aux_ctx_tensors ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv ctx.max_seqlen_kv = max_seqlen_kv
...@@ -2025,15 +2499,23 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2025,15 +2499,23 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout = qkv_layout ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend ctx.fused_attention_backend = \
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
ctx.use_FAv2_bwd = use_FAv2_bwd ctx.use_FAv2_bwd = use_FAv2_bwd
return out return out_ret
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha:
assert (isinstance(d_out, Float8Tensor)
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
d_out_f8tensor = d_out
d_out = d_out._data
d_out = d_out.contiguous() d_out = d_out.contiguous()
q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors (q, k, v, out, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors
if not ctx.aux_ctx_tensors[0].is_contiguous(): if not ctx.aux_ctx_tensors[0].is_contiguous():
ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
if ctx.use_FAv2_bwd: if ctx.use_FAv2_bwd:
...@@ -2054,12 +2536,110 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2054,12 +2536,110 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., :d_out.shape[-1]] dk = dk[..., :d_out.shape[-1]]
dv = dv[..., :d_out.shape[-1]] dv = dv[..., :d_out.shape[-1]]
else: else:
with torch.cuda.nvtx.range("_FusedAttn"):
if ctx.fp8:
if _NVTE_DEBUG:
print('[DotProductAttention]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False)
if ctx.fp8_meta["recipe"].fp8_mha:
d_out_fp8 = d_out
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else:
d_out_fp8 = cast_to_fp8(
d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
).view(d_out.shape)
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8,
fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors,
ctx.fused_attention_backend,
fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o,
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor(data=dq_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
dk = Float8Tensor(data=dk_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
dv = Float8Tensor(data=dv_fp8,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=d_out_f8tensor.dtype,
)
else:
qkv_group = len(ctx.qkv_layout.split('_'))
if qkv_group == 1:
dim = ctx.qkv_layout.find('3')
dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim)
dqkv_c_fp8 = dqkv_fp8.view(-1,
dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1])
dqkv = cast_from_fp8(dqkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape)
dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1])
dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]]
if qkv_group == 2:
dq = cast_from_fp8(
dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
dim = ctx.qkv_layout.split('_')[1].find('2')
dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim)
dkv_c_fp8 = dkv_fp8.view(-1,
dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1])
dkv = cast_from_fp8(dkv_c_fp8,
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape)
dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1])
dk, dv = [x.squeeze(dim) for x in [dk, dv]]
if qkv_group == 3:
dq = cast_from_fp8(
dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape)
dk = cast_from_fp8(
dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape)
dv = cast_from_fp8(
dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]),
ctx.fp8_meta["scaling_bwd"], META_DQKV,
fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape)
else:
if _NVTE_DEBUG:
print('[DotProductAttention]: using non-FP8 backward')
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dk, dv, *rest = fused_attn_bwd( dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
q, k, v, out, d_out, q, k, v, out, d_out,
ctx.qkv_dtype, ctx.aux_ctx_tensors, ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,
ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
...@@ -2074,7 +2654,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -2074,7 +2654,7 @@ class FusedAttnFunc(torch.autograd.Function):
None, None, None, None, None, None) None, None, None, None, None, None)
class FusedAttention(torch.nn.Module): class FusedAttention(TransformerEngineBaseModule):
"""Dot product attention, with multiple backends: """Dot product attention, with multiple backends:
1. FusedAttnBackend["F16_max512_seqlen"] 1. FusedAttnBackend["F16_max512_seqlen"]
...@@ -2110,6 +2690,8 @@ class FusedAttention(torch.nn.Module): ...@@ -2110,6 +2690,8 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self", attention_type: str = "self",
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
deterministic: bool = False, deterministic: bool = False,
tp_size: int = 1,
tp_group: Optional[dist_group_type] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -2136,6 +2718,15 @@ class FusedAttention(torch.nn.Module): ...@@ -2136,6 +2718,15 @@ class FusedAttention(torch.nn.Module):
if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
self.tp_size = tp_size
self.tp_group = tp_group
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[Float8Tensor]:
"""Needs override."""
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, self,
...@@ -2157,6 +2748,7 @@ class FusedAttention(torch.nn.Module): ...@@ -2157,6 +2748,7 @@ class FusedAttention(torch.nn.Module):
cp_group: Optional[dist_group_type] = None, cp_group: Optional[dist_group_type] = None,
cp_global_ranks: List[int] = None, cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None, cp_stream: torch.cuda.Stream = None,
is_first_microbatch: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
...@@ -2164,9 +2756,9 @@ class FusedAttention(torch.nn.Module): ...@@ -2164,9 +2756,9 @@ class FusedAttention(torch.nn.Module):
!= tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), 'No fused attention backend supports this input combination!' ), 'No fused attention backend supports this input combination!'
assert ( assert (
(query_layer.dtype in [torch.float16, torch.bfloat16]) (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
and (key_layer.dtype in [torch.float16, torch.bfloat16]) and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
and (value_layer.dtype in [torch.float16, torch.bfloat16]) and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
), 'FusedAttention only supports FP16 and BF16 data types.' ), 'FusedAttention only supports FP16, BF16, and FP8 (uint8) data types.'
assert ( assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
...@@ -2248,7 +2840,22 @@ class FusedAttention(torch.nn.Module): ...@@ -2248,7 +2840,22 @@ class FusedAttention(torch.nn.Module):
if qkv_format == 'sbhd': if qkv_format == 'sbhd':
output = output.transpose(0,1).contiguous() output = output.transpose(0,1).contiguous()
else: else:
with self.prepare_forward(query_layer,
is_first_microbatch,
num_gemms=3,
allow_non_contiguous=True) as query_layer:
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
forced_fp8_dpa = ""
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
forced_fp8_dpa = " (forced)"
if _NVTE_DEBUG:
print("[DotProductAttention]: "
f"""using fp8_recipe.fp8_mha={self.fp8_meta["recipe"].fp8_mha}, """
f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}"""
f"""{forced_fp8_dpa} and """
f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""")
output = FusedAttnFunc.apply( output = FusedAttnFunc.apply(
self.training, self.training,
max_seqlen_q, max_seqlen_kv, max_seqlen_q, max_seqlen_kv,
...@@ -2265,6 +2872,10 @@ class FusedAttention(torch.nn.Module): ...@@ -2265,6 +2872,10 @@ class FusedAttention(torch.nn.Module):
None, # rng_gen None, # rng_gen
fused_attention_backend, fused_attention_backend,
use_FAv2_bwd, use_FAv2_bwd,
self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
self.fp8_meta,
self.tp_size,
self.tp_group,
)
# ...hd -> ...(hd)
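A brief, hedged usage sketch of the FP8 DPA controls referenced in the hunk above (an editor's illustration, not part of this diff; it assumes DelayedScaling exposes the fp8_dpa/fp8_mha flags added by this PR and that _NVTE_DEBUG is read from the NVTE_DEBUG environment variable):

import os
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

os.environ["NVTE_DEBUG"] = "1"        # print the [DotProductAttention] banner above
os.environ["NVTE_FP8_DPA_BWD"] = "0"  # run the backward pass in F16 while fwd uses FP8

fp8_recipe = recipe.DelayedScaling(fp8_dpa=True)  # fp8_mha=True would force fp8_dpa too
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    pass  # run any module whose core attention is DotProductAttention here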
...@@ -2463,7 +3074,9 @@ class DotProductAttention(torch.nn.Module):
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
tp_size=self.tp_size,
tp_group=self.tp_group)
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
...@@ -2532,6 +3145,7 @@ class DotProductAttention(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
is_first_microbatch: Optional[bool] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
...@@ -2635,6 +3249,19 @@ class DotProductAttention(torch.nn.Module):
Adjustments of the sequence_len_offset should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
""" """
assert (
...@@ -2746,6 +3373,12 @@ class DotProductAttention(torch.nn.Module):
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimension in 'key_layer' and 'value_layer'!"""
if (isinstance(query_layer, Float8Tensor)
and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor)):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format = qkv_format)
else:
qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format = qkv_format)
...@@ -2767,8 +3400,13 @@ class DotProductAttention(torch.nn.Module):
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer])
):
use_flash_attention = False
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
):
use_fused_attention = False
# Filter: Device and dimensions.
...@@ -2865,8 +3503,10 @@ class DotProductAttention(torch.nn.Module):
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype]
if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype,
TE_DType[key_layer.dtype]
if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype,
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
...@@ -2879,7 +3519,9 @@ class DotProductAttention(torch.nn.Module):
)
# check whether the requested fused attention backend (F16 or FP8) is available
is_backend_avail = (fused_attention_backend in
[FusedAttnBackend["F16_max512_seqlen"],
FusedAttnBackend["F16_arbitrary_seqlen"],
FusedAttnBackend["FP8"]])
use_fused_attention = ( \
use_fused_attention and is_backend_avail and \
(not context_parallel or \
...@@ -2950,6 +3592,8 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
...@@ -2959,8 +3603,7 @@ class DotProductAttention(torch.nn.Module):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
is_first_microbatch=is_first_microbatch)
return self.fused_attention(
query_layer,
key_layer,
...@@ -2968,6 +3611,8 @@ class DotProductAttention(torch.nn.Module):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
...@@ -2977,8 +3622,7 @@ class DotProductAttention(torch.nn.Module):
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
cp_stream=self.cp_stream,
is_first_microbatch=is_first_microbatch)
assert (not context_parallel), \
"Context parallelism is only implemented with Flash Attention and Fused Attention!"
...@@ -3552,6 +4196,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
)
num_queries_per_key_value = (self.num_attention_heads_per_partition //
...@@ -3603,6 +4248,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
)
if self.qkv_weight_interleaved:
...@@ -3633,6 +4279,9 @@ class MultiheadAttention(torch.nn.Module):
key_layer, value_layer = torch.split(
mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim,
)
key_layer, value_layer = (x.reshape(
x.size(0), x.size(1), -1, self.hidden_size_per_attention_head,
) for x in (key_layer, value_layer))
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
...@@ -3648,6 +4297,7 @@ class MultiheadAttention(torch.nn.Module):
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
is_first_module_in_mha=True, # specific to FP8 MHA
)
# [sq, b, hp] --> [sq, b, np, hn]
...@@ -3662,6 +4312,9 @@ class MultiheadAttention(torch.nn.Module):
# ======================================================
if rotary_pos_emb is not None:
assert (not isinstance(query_layer, Float8Tensor)
and not isinstance(key_layer, Float8Tensor)
), "RoPE is not supported for Float8Tensors!"
# duplicate the pos_emb for self attention
if not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = ((rotary_pos_emb,) * 2)
...
...@@ -84,6 +84,7 @@ def fused_attn_fwd_qkvpacked(
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
...@@ -119,6 +120,8 @@ def fused_attn_fwd_qkvpacked(
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
...@@ -206,6 +209,8 @@ def fused_attn_fwd_qkvpacked(
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (d_scale_s is not None
), "d_scale_s is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
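For readers unfamiliar with the delayed-scaling bookkeeping behind these arguments, here is a hedged, self-contained sketch of what the q_scale_*/d_scale_*/amax_* tensors represent (an editor's illustration, not library code; 448.0 is the E4M3 maximum and torch.float8_e4m3fn needs PyTorch >= 2.1):

import torch

x = torch.randn(4, 4)
amax = x.abs().max()                  # tracked per tensor, e.g. amax_s, amax_o
q_scale = 448.0 / amax                # quantization scale, e.g. q_scale_s
d_scale = 1.0 / q_scale               # dequantization scale, e.g. d_scale_s

x_fp8 = (x * q_scale).to(torch.float8_e4m3fn)   # quantize
x_back = x_fp8.to(torch.float32) * d_scale      # dequantize; x_back ~= x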
...@@ -220,7 +225,7 @@ def fused_attn_fwd_qkvpacked(
max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, qkv_dtype,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
rng_gen, rng_elts_per_thread,
)
...@@ -235,12 +240,14 @@ def fused_attn_bwd_qkvpacked(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
d_scale_dp: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
...@@ -272,6 +279,8 @@ def fused_attn_bwd_qkvpacked(
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of QKV; in tex.DType, not torch.dtype
dqkv_dtype: tex.DType
data type of dQKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
...@@ -285,6 +294,8 @@ def fused_attn_bwd_qkvpacked(
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
d_scale_dp: torch.Tensor, default = None
input tensor for the dequantization of dP in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
...@@ -336,6 +347,7 @@ def fused_attn_bwd_qkvpacked(
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
...@@ -348,8 +360,8 @@ def fused_attn_bwd_qkvpacked(
output_tensors = tex.fused_attn_bwd_qkvpacked(
max_seqlen, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
...@@ -368,6 +380,7 @@ def fused_attn_fwd_kvpacked(
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
...@@ -410,6 +423,8 @@ def fused_attn_fwd_kvpacked(
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
...@@ -496,12 +511,25 @@ def fused_attn_fwd_kvpacked(
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (d_scale_s is not None
), "d_scale_s is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
), "q_scale_o is required as an input for FP8 fused attention."
assert (amax_s is not None
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_fwd_kvpacked(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
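The rng_elts_per_thread expression above is ordinary ceiling division, ceil(max_seqlen_q^2 / THREADS_PER_CTA); a quick check with illustrative numbers (the actual BACKEND_F16m512_FP8_THREADS_PER_CTA value may differ):

max_seqlen_q, threads_per_cta = 512, 128          # illustrative values only
assert (max_seqlen_q * max_seqlen_q + threads_per_cta - 1) // threads_per_cta == 2048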
...@@ -519,12 +547,14 @@ def fused_attn_bwd_kvpacked(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
d_scale_dp: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
...@@ -562,7 +592,9 @@ def fused_attn_bwd_kvpacked(
input tensor dO (gradient of O);
same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details)
qkv_dtype: tex.DType
data type of Q and KV; in tex.DType, not torch.dtype
dqkv_dtype: tex.DType
data type of dQ and dKV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
...@@ -576,6 +608,8 @@ def fused_attn_bwd_kvpacked(
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
d_scale_dp: torch.Tensor, default = None
input tensor for the dequantization of dP in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
...@@ -631,6 +665,7 @@ def fused_attn_bwd_kvpacked(
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
...@@ -643,8 +678,8 @@ def fused_attn_bwd_kvpacked(
output_tensors = tex.fused_attn_bwd_kvpacked(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
...@@ -664,6 +699,7 @@ def fused_attn_fwd(
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_o: torch.Tensor = None,
amax_s: torch.Tensor = None,
...@@ -710,6 +746,8 @@ def fused_attn_fwd(
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_o: torch.Tensor, default = None
...@@ -798,12 +836,25 @@ def fused_attn_fwd(
rng_elts_per_thread = (max_seqlen_q * max_seqlen_q
+ BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA
assert (d_scale_qkv is not None
), "d_scale_qkv is required as an input for FP8 fused attention."
assert (d_scale_s is not None
), "d_scale_s is required as an input for FP8 fused attention."
assert (q_scale_s is not None
), "q_scale_s is required as an input for FP8 fused attention."
assert (q_scale_o is not None
), "q_scale_o is required as an input for FP8 fused attention."
assert (amax_s is not None
), "amax_s is required as an input for FP8 fused attention."
assert (amax_o is not None
), "amax_o is required as an input for FP8 fused attention."
# execute kernel
output_tensors = tex.fused_attn_fwd(
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
attn_bias, rng_gen, rng_elts_per_thread,
)
...@@ -822,12 +873,14 @@ def fused_attn_bwd(
o: torch.Tensor,
d_o: torch.Tensor,
qkv_dtype: tex.DType,
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
d_scale_do: torch.Tensor = None,
d_scale_dp: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
q_scale_dp: torch.Tensor = None,
q_scale_dqkv: torch.Tensor = None,
...@@ -869,6 +922,8 @@ def fused_attn_bwd(
same shape as Q
qkv_dtype: tex.DType
data type of Q, K and V; in tex.DType, not torch.dtype
dqkv_dtype: tex.DType
data type of dQ, dK and dV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
...@@ -882,6 +937,8 @@ def fused_attn_bwd(
input tensor for the dequantization of O in FP8 computations
d_scale_do: torch.Tensor, default = None
input tensor for the dequantization of dO in FP8 computations
d_scale_dp: torch.Tensor, default = None
input tensor for the dequantization of dP in FP8 computations
q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations
q_scale_dp: torch.Tensor, default = None
...@@ -941,6 +998,7 @@ def fused_attn_bwd(
assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention."
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention."
assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention."
assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention."
assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention."
assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention."
assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention."
...@@ -953,8 +1011,8 @@ def fused_attn_bwd(
output_tensors = tex.fused_attn_bwd(
max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
)
...
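To summarize the FP8 requirements asserted throughout this file, here is a hedged sketch of the per-direction scaling tensors the FP8 path expects (names mirror the keyword arguments above; the 1-element float32 placeholders are an editor's assumption, not the library's own initialization):

import torch

def _scale():  # placeholder 1-element scale tensor
    return torch.ones(1, dtype=torch.float32, device="cuda")

def _amax():   # placeholder 1-element amax tensor
    return torch.zeros(1, dtype=torch.float32, device="cuda")

fp8_fwd_meta = {
    "d_scale_qkv": _scale(), "d_scale_s": _scale(),   # dequantize Q/K/V and S
    "q_scale_s": _scale(), "q_scale_o": _scale(),     # quantize S and O
    "amax_s": _amax(), "amax_o": _amax(),
}
fp8_bwd_meta = {
    "d_scale_qkv": _scale(), "d_scale_s": _scale(), "d_scale_o": _scale(),
    "d_scale_do": _scale(), "d_scale_dp": _scale(),   # descale of dP is new in this PR
    "q_scale_s": _scale(), "q_scale_dp": _scale(), "q_scale_dqkv": _scale(),
    "amax_dp": _amax(), "amax_dqkv": _amax(),
}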
...@@ -786,9 +786,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (n_chunk * m) * D.element_size();
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
// Get output and workspace data pointers
...
...@@ -32,6 +32,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
...@@ -51,11 +52,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
...@@ -74,6 +77,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
...@@ -95,11 +99,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
...@@ -119,6 +125,7 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor V,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
...@@ -141,11 +148,13 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
...
...@@ -97,6 +97,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
...@@ -126,22 +127,24 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
// FP8
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O ";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
...@@ -261,11 +264,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
...@@ -284,26 +289,29 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
auto h = q_shape[q_shape.size() - 2];
// create output tensor dQKV
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
at::Tensor dQKV = torch::empty_like(QKV, options);
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto d = q_shape[q_shape.size() - 1];
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
dQKV.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!descale_dP.has_value()) || (!scale_S.has_value())
|| (!scale_dP.has_value()) || (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, ";
err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, ");
err_tensors = err_tensors + std::string("amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape,
...@@ -311,14 +319,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_dP.value().data_ptr(),
scale_dP.value().data_ptr(), descale_dP.value().data_ptr());
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
...@@ -327,13 +334,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr);
te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape,
dqkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
...@@ -433,6 +440,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
...@@ -458,24 +466,26 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// FP8
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
O.fill_(0);
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O ";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
...@@ -608,11 +618,13 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
...@@ -635,15 +647,18 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
auto d = q_shape[q_shape.size() - 1];
// create output tensors dQ and dKV
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
at::Tensor dQ = torch::empty_like(Q, options);
at::Tensor dKV = torch::empty_like(KV, options);
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if (set_zero
&& ((h_q * d)% block_size == 0)
&& ((h_kv * d)% block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
...@@ -652,11 +667,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
}
if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!descale_dP.has_value()) || (!scale_S.has_value())
|| (!scale_dP.has_value()) || (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, ";
err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, ");
err_tensors = err_tensors + std::string("amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
}
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
...@@ -666,16 +682,15 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.value().data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16
...@@ -686,15 +701,15 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -686,15 +701,15 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr); qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0}, te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr); DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0}, te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr); DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
} else { } else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
} }
...@@ -806,6 +821,7 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -806,6 +821,7 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor V, const at::Tensor V,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_S,
...@@ -832,14 +848,17 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -832,14 +848,17 @@ std::vector<at::Tensor> fused_attn_fwd(
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1]; auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0)) { if (set_zero
&& ((h * d) % block_size == 0)
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else { } else {
O.fill_(0); O.fill_(0);
} }
if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!scale_S.has_value()) || (!scale_O.has_value())
|| (!amax_S.has_value()) || (!amax_O.has_value())) { || (!amax_S.has_value()) || (!amax_O.has_value())) {
std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O ";
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
} }
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
...@@ -848,10 +867,9 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -848,10 +867,9 @@ std::vector<at::Tensor> fused_attn_fwd(
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr());
at::Tensor descale_S = torch::empty_like(scale_S.value());
te_S = makeTransformerEngineTensor(nullptr, {0}, te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, amax_S.value().data_ptr(), DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.data_ptr()); scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
...@@ -990,11 +1008,13 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -990,11 +1008,13 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor O, const at::Tensor O,
const at::Tensor dO, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, const c10::optional<at::Tensor> scale_dQKV,
...@@ -1011,7 +1031,7 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1011,7 +1031,7 @@ std::vector<at::Tensor> fused_attn_bwd(
auto h_q = q_shape[q_shape.size() - 2]; auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1]; auto d = q_shape[q_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
at::Tensor dQ; at::Tensor dQ;
at::Tensor dK; at::Tensor dK;
...@@ -1046,7 +1066,7 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1046,7 +1066,7 @@ std::vector<at::Tensor> fused_attn_bwd(
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
break; break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
dQ = torch::empty_like(Q); dQ = torch::empty_like(Q, options);
tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()}; tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2));
dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
...@@ -1058,7 +1078,7 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1058,7 +1078,7 @@ std::vector<at::Tensor> fused_attn_bwd(
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3);
break; break;
case NVTE_QKV_Layout_Group::NVTE_HD_H2D: case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
dQ = torch::empty_like(Q); dQ = torch::empty_like(Q, options);
tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()}; tmp_shape = std::vector<int64_t>{k_sizes.begin(), k_sizes.end()};
tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2));
dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options);
...@@ -1068,9 +1088,9 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1068,9 +1088,9 @@ std::vector<at::Tensor> fused_attn_bwd(
torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2);
break; break;
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
dQ = torch::empty_like(Q); dQ = torch::empty_like(Q, options);
dK = torch::empty_like(K); dK = torch::empty_like(K, options);
dV = torch::empty_like(V); dV = torch::empty_like(V, options);
break; break;
default: default:
NVTE_ERROR("QKV layout not supported!"); NVTE_ERROR("QKV layout not supported!");
...@@ -1085,7 +1105,8 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1085,7 +1105,8 @@ std::vector<at::Tensor> fused_attn_bwd(
&& ((h_kv * d) % block_size == 0) && ((h_kv * d) % block_size == 0)
&& dQ.is_contiguous() && dQ.is_contiguous()
&& dK.is_contiguous() && dK.is_contiguous()
&& dV.is_contiguous()) { && dV.is_contiguous()
&& (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
...@@ -1096,11 +1117,12 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1096,11 +1117,12 @@ std::vector<at::Tensor> fused_attn_bwd(
} }
if ((!descale_QKV.has_value()) || (!descale_S.has_value()) if ((!descale_QKV.has_value()) || (!descale_S.has_value())
|| (!descale_O.has_value()) || (!descale_dO.has_value()) || (!descale_O.has_value()) || (!descale_dO.has_value())
|| (!scale_S.has_value()) || (!scale_dP.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value())
|| (!scale_dQKV.has_value()) || (!scale_dP.has_value()) || (!scale_dQKV.has_value())
|| (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) {
std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, ";
err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, ");
err_tensors = err_tensors + std::string("amax_dP and amax_dQKV ");
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
} }
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape,
...@@ -1112,18 +1134,17 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1112,18 +1134,17 @@ std::vector<at::Tensor> fused_attn_bwd(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); qkv_type, nullptr, nullptr, descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr()); scale_S.value().data_ptr(), descale_S.value().data_ptr());
at::Tensor descale_dP = torch::empty_like(scale_dP.value());
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), amax_dP.value().data_ptr(), scale_dP.value().data_ptr(),
descale_dP.data_ptr()); descale_dP.value().data_ptr());
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type, te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, qkv_type, te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, qkv_type, te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type,
amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
// BF16 or FP16 // BF16 or FP16
...@@ -1136,17 +1157,17 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1136,17 +1157,17 @@ std::vector<at::Tensor> fused_attn_bwd(
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr); qkv_type, nullptr, nullptr, nullptr);
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0}, te_S = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr); DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0}, te_dP = makeTransformerEngineTensor(nullptr, {0},
DType::kFloat32, nullptr, nullptr, nullptr); DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape,
qkv_type, nullptr, nullptr, nullptr); dqkv_type, nullptr, nullptr, nullptr);
} else { } else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
} }
......
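For reference, the FP8 backward entry points above now take descale_dO and descale_dP as inputs rather than allocating a dummy descale_dP internally, so the full set of scaling tensors must be supplied. A minimal Python sketch of that precondition (the helper name is hypothetical, not part of the extension API):

from typing import Optional
import torch

def check_fp8_bwd_tensors(**tensors: Optional[torch.Tensor]) -> None:
    # Mirrors the C++ check: all of these must be provided for FP8 operation.
    required = ("descale_QKV", "descale_S", "descale_O", "descale_dO", "descale_dP",
                "scale_S", "scale_dP", "scale_dQKV", "amax_dP", "amax_dQKV")
    missing = [name for name in required if tensors.get(name) is None]
    if missing:
        raise ValueError(", ".join(missing) + " are required for FP8 operation.")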
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Tensor class with FP8 data""" """Tensor class with FP8 data"""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Tuple, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -233,6 +233,87 @@ class _IdentityFunc(torch.autograd.Function): ...@@ -233,6 +233,87 @@ class _IdentityFunc(torch.autograd.Function):
def backward(ctx, grad): def backward(ctx, grad):
return grad.to(ctx.input_dtype), None return grad.to(ctx.input_dtype), None
class _ViewFunc(torch.autograd.Function):
"""View function
View the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Optional[Tuple[int]] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.view(*shape),
)
return tensor.view(*shape)
@staticmethod
def backward(ctx,
grad: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.view(ctx.shape),
)
return dgrad, None
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the Float8Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
shape: Optional[Tuple[int]] = None,
) -> torch.Tensor:
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
return tensor
# Construct new tensor if shape is provided
if isinstance(tensor, Float8Tensor):
return Float8Tensor.make_like(
tensor,
data=tensor._data.reshape(*shape),
)
return tensor.reshape(*shape)
@staticmethod
def backward(ctx,
grad: torch.Tensor,
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
grad,
data=grad._data.reshape(ctx.shape),
)
return dgrad, None
return grad.reshape(ctx.shape), None
class Float8Tensor(torch.Tensor): class Float8Tensor(torch.Tensor):
"""Experimental tensor class with FP8 data """Experimental tensor class with FP8 data
...@@ -453,6 +534,12 @@ class Float8Tensor(torch.Tensor): ...@@ -453,6 +534,12 @@ class Float8Tensor(torch.Tensor):
def clone(self) -> Float8Tensor: def clone(self) -> Float8Tensor:
return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) return _IdentityFunc.apply(self, {"data": self._data.detach().clone()})
def view(self, *shape: Tuple[int]) -> Float8Tensor:
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8Tensor:
return _ReshapeFunc.apply(self, shape)
def expand_as(self, other: torch.Tensor): def expand_as(self, other: torch.Tensor):
if other is self: if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes # Note: expand_as is hackily used to create dummy autograd nodes
......
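A minimal sketch of how the new view()/reshape() overrides behave, assuming x may be either a plain tensor or a Float8Tensor produced by an FP8 module; the helper below is illustrative only:

import torch

def flatten_heads(x: torch.Tensor, h: int, d: int) -> torch.Tensor:
    # For a Float8Tensor this dispatches to _ViewFunc: the uint8 payload (x._data)
    # is reshaped and the FP8 scale/metadata are carried over, so no dequantize
    # happens and backward restores the original shape.
    return x.view(*x.shape[:-2], h * d)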
...@@ -202,6 +202,11 @@ class FP8GlobalStateManager: ...@@ -202,6 +202,11 @@ class FP8GlobalStateManager:
# `fp8_param_to_autocast`. This is used for keeping track of FP8 weights # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
# in an autocasted region and cross reference them in `float8_tensor.py` # in an autocasted region and cross reference them in `float8_tensor.py`
# to perform the forward amax reduction. # to perform the forward amax reduction.
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if fp8_meta_tensor_key not in fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
if forward and fp8_weights is not None: if forward and fp8_weights is not None:
autocast_key = cls.get_unique_autocast_key( autocast_key = cls.get_unique_autocast_key(
fp8_meta["recipe"], fp8_meta["fp8_group"]) fp8_meta["recipe"], fp8_meta["fp8_group"])
...@@ -217,7 +222,6 @@ class FP8GlobalStateManager: ...@@ -217,7 +222,6 @@ class FP8GlobalStateManager:
key = cls.get_key_in_buffer( key = cls.get_key_in_buffer(
forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"])
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if key not in cls.global_amax_buffer: if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
......
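The guard above lets the global amax-reduction bookkeeping skip modules whose fp8_meta does not carry the meta-tensor key for the given direction (the diff's comment cites non-parameter FP8 modules such as DPA). A simplified, illustrative version of the pattern:

def collect_amaxes(fp8_metas, forward=True):
    # Skip any module that has no scaling tensors for this direction,
    # e.g. an attention module with no FP8 weights.
    key = "scaling_fwd" if forward else "scaling_bwd"
    return [meta[key].amax_history[0] for meta in fp8_metas if key in meta]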
...@@ -268,6 +268,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -268,6 +268,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
for meta_key in fp8_meta_tensor_keys: for meta_key in fp8_meta_tensor_keys:
if meta_key not in self.fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
curr_len = self.fp8_meta[meta_key].amax_history.shape[0] curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
if length == curr_len: if length == curr_len:
continue continue
...@@ -568,6 +571,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -568,6 +571,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
inp: torch.Tensor, inp: torch.Tensor,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
num_gemms: int = 1, num_gemms: int = 1,
allow_non_contiguous: bool = False,
) -> Generator[torch.Tensor, None, None]: ) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD. """Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know The context manager is needed because there isn't a way for a module to know
...@@ -610,7 +614,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -610,7 +614,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
if not allow_non_contiguous:
yield inp.contiguous() yield inp.contiguous()
else:
yield inp
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
...@@ -645,8 +652,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -645,8 +652,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
R4: bias gradient on R1. R4: bias gradient on R1.
""" """
if isinstance(grad_output, Float8Tensor):
grad_output._data = grad_output._data.contiguous()
else:
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
grad_output_mat = grad_output.view((-1, grad_output.shape[-1])) grad_output_mat = grad_output.view(-1, grad_output.shape[-1])
gather_grad_output = row_parallel_mode and ctx.sequence_parallel gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# No-FP8 case: bgrad is fused with wgrad for this case. # No-FP8 case: bgrad is fused with wgrad for this case.
...@@ -684,6 +694,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -684,6 +694,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else: else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
if not isinstance(grad_output_mat, Float8Tensor):
cast_to_fp8( cast_to_fp8(
grad_output_mat, grad_output_mat,
ctx.fp8_meta["scaling_bwd"], ctx.fp8_meta["scaling_bwd"],
...@@ -691,9 +702,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -691,9 +702,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_dtype_backward, fp8_dtype_backward,
out=grad_output_c, out=grad_output_c,
) )
else:
grad_output_c = grad_output_mat
if not ctx.ub_overlap_ag: if not ctx.ub_overlap_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
if not isinstance(grad_output_c, Float8Tensor):
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
grad_output_t = grad_output_c.transpose_2d()
else: else:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
grad_output_t = None grad_output_t = None
...@@ -702,14 +718,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -702,14 +718,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# FP8 case without gather: cast, transpose, bgrad fused # FP8 case without gather: cast, transpose, bgrad fused
if ctx.use_bias: if ctx.use_bias:
grad_output_mat_no_fp8 = grad_output_mat
if isinstance(grad_output_mat, Float8Tensor):
grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype)
grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused( grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
grad_output_mat, grad_output_mat_no_fp8,
ctx.fp8_meta["scaling_bwd"], ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
) )
else: else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if isinstance(grad_output_mat, Float8Tensor):
grad_output_c = grad_output_mat
grad_output_t = grad_output_c.transpose_2d()
else:
grad_output_c, grad_output_t = fp8_cast_transpose_fused( grad_output_c, grad_output_t = fp8_cast_transpose_fused(
grad_output_mat, grad_output_mat,
ctx.fp8_meta["scaling_bwd"], ctx.fp8_meta["scaling_bwd"],
...@@ -718,12 +741,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -718,12 +741,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
else: else:
grad_output_t = None grad_output_t = None
if not isinstance(grad_output_mat, Float8Tensor):
grad_output_c = cast_to_fp8( grad_output_c = cast_to_fp8(
grad_output_mat, grad_output_mat,
ctx.fp8_meta["scaling_bwd"], ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
) )
else:
grad_output_c = grad_output_mat
grad_bias = None grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias return grad_output_mat, grad_output_c, grad_output_t, grad_bias
......
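A condensed sketch of the dispatch grad_output_preprocess now performs (simplified; cast_fn and transpose_fn stand in for the fused cast/transpose kernels): when the incoming gradient is already a Float8Tensor, its uint8 payload is reused and transposed with transpose_2d() instead of being cast to FP8 again.

from transformer_engine.pytorch.float8_tensor import Float8Tensor

def prep_grad_output(grad_output, cast_fn, transpose_fn):
    if isinstance(grad_output, Float8Tensor):
        grad_output._data = grad_output._data.contiguous()
        grad_c = grad_output                  # already FP8: no extra cast
        grad_t = grad_output.transpose_2d()   # FP8 transpose of the payload
    else:
        grad_c = cast_fn(grad_output.contiguous())
        grad_t = transpose_fn(grad_c)
    return grad_c, grad_t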
...@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo ...@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
...@@ -190,6 +191,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -190,6 +191,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out = ln_out_total ln_out = ln_out_total
if fp8: if fp8:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using FP8 forward')
bias_dtype = ( bias_dtype = (
torch.bfloat16 torch.bfloat16
if activation_dtype == torch.float32 if activation_dtype == torch.float32
...@@ -230,6 +234,15 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -230,6 +234,15 @@ class _LayerNormLinear(torch.autograd.Function):
) )
weight_t_fp8 = None weight_t_fp8 = None
if fp8_meta["recipe"].fp8_mha:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_meta["scaling_fwd"],
fp8_dtype_forward,
torch.uint8)
else:
out_index, meta_tensor, output_te_dtype, output_dtype = (
None, None, None, activation_dtype)
out, _ = tex.fp8_gemm( out, _ = tex.fp8_gemm(
weight_fp8._data, weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
...@@ -239,7 +252,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -239,7 +252,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
activation_dtype, output_dtype,
get_workspace(), get_workspace(),
bias=bias, bias=bias,
use_bias=use_bias, use_bias=use_bias,
...@@ -247,8 +260,22 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -247,8 +260,22 @@ class _LayerNormLinear(torch.autograd.Function):
ub_algo=ub_algo if ub_overlap_ag else None, ub_algo=ub_algo if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None,
out_index=out_index,
fp8_meta_tensor=meta_tensor,
D_dtype=output_te_dtype,
)
if output_dtype == torch.uint8:
out = Float8Tensor(data=out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_dtype=fp8_dtype_forward,
dtype=activation_dtype,
) )
else: else:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using non-FP8 forward')
# Cast for native AMP # Cast for native AMP
weight = cast_if_needed(weight, activation_dtype) weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
...@@ -338,7 +365,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -338,7 +365,6 @@ class _LayerNormLinear(torch.autograd.Function):
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
shape = list(inp.shape) shape = list(inp.shape)
...@@ -352,6 +378,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -352,6 +378,10 @@ class _LayerNormLinear(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[0]._scale_inv
with torch.cuda.nvtx.range("_LayerNormLinear_backward"): with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
( (
inputmat, inputmat,
...@@ -465,6 +495,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -465,6 +495,9 @@ class _LayerNormLinear(torch.autograd.Function):
ub_obj = None ub_obj = None
if ctx.fp8: if ctx.fp8:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using FP8 backward')
fp8_dtype_forward = get_fp8_te_dtype( fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True ctx.fp8_meta["recipe"], fprop_tensor=True
) )
...@@ -486,7 +519,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -486,7 +519,8 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_c, grad_output_c._data
if isinstance(grad_output_c, Float8Tensor) else grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
...@@ -503,6 +537,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -503,6 +537,9 @@ class _LayerNormLinear(torch.autograd.Function):
) )
clear_tensor_data(grad_output_c) clear_tensor_data(grad_output_c)
else: else:
if _NVTE_DEBUG:
print('[LayerNormLinear]: using non-FP8 backward')
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm( _, _, _ = tex.gemm(
weight, weight,
...@@ -551,7 +588,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -551,7 +588,8 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_t, grad_output_t._data
if isinstance(grad_output_t, Float8Tensor) else grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
......
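A hedged usage sketch of the two knobs this file now reads: NVTE_DEBUG=1 turns on the branch prints, and a recipe with the fp8_mha flag (the attribute referenced via fp8_meta["recipe"].fp8_mha above; the constructor argument name is assumed to match) keeps the GEMM output as a Float8Tensor so downstream attention can consume it without re-casting.

import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

os.environ["NVTE_DEBUG"] = "1"                      # print FP8/non-FP8 branch choices
fp8_recipe = recipe.DelayedScaling(fp8_mha=True)    # fp8_mha kwarg assumed from the diff
layer = te.LayerNormLinear(1024, 3072).cuda()
inp = torch.randn(8, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)                                # Float8Tensor output when fp8_mha is set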